# Flattened Rust Source for LLM Ingestion - **Repository root:** `/root/mrpf` - **Generated (UTC):** `2025-11-17T23:00:04Z` - **File count:** 263 > Notes: This document contains only Rust source files (*.rs) and Markdown files (*.md). > Directories such as .git/, target/, node_modules/, build/, and book/build are excluded. ## File Tree (Rust files only) .github/ copilot-instructions.md instructions/ mrpf_models.instructions.md ACTIELIJST.md README.md book/ src/ AI/ README.md Database model IDEAS TO BE CATEGORIZED INTO mdBook.md README.md SUMMARY.md api/ README.md apple_universal_app/ README.md mrpf_scanner_api.md cert_transparency_records/ README.md TODO.md coding_guidelines/ README.md error_handling.md contributing.md core_principles.md mpf_previous_iteration/ README.md what_works_and_what_not.md network_engine/ README.md dns_resolver.md http11_scanner.md ideas/ MY_FIRST_ORIGINAL_IDEA_FOR_BETTER_THROUGHPUT.md README.md TLS_parse_improvements.md improvements_to_investigate_on_linux.md recucing_pselect_calls.md refactor_network_stack.md tls_parse_improvements.md infrastructure.md os_network_stack_challenges.md tcp_syn_scanner.md tls_scraper.md whois_resolver.md security_ideas/ README.md ai_wordlists.md aws_protections.md aws_web_identity.md random_ideas.md s3_dns_recon.md system_architecture/ README.md task_manager/ README.md data.md examples.md infrastructure.md initial_description_whilst_building.md task_generators.md tasks_and_containers.md templates.md workers.md clients/ mrpf_api_client/ src/ base.rs domains.rs lib.rs mrpf_app/ README.md mrpf_cli/ README.md src/ main.rs progress/ dns.rs mod.rs tcpsyn.rs tls.rs infrastructure/ api/ README.md src/ api/ src/ error.rs handlers/ apn_devices.rs domains.rs ips.rs jobs.rs mod.rs notifications.rs targets.rs template_keys.rs word_categories.rs words.rs main.rs utils/ apn_notifications.rs apn_token_provider.rs mod.rs bare_metal_worker/ OPTIMIZELINUX.md README.md db_and_network/ README.md task_manager/ README.md src/ manager/ src/ main.rs worker/ src/ main.rs promptios.md scanners/ mrpf_dns_resolver/ DNS_RECON_TRICKS.md README.md src/ lib.rs progress.rs receive.rs scanner.rs transmit.rs mrpf_engine/ benches/ intranges.rs optimize_ip_calc.rs src/ config.rs connections/ mod.rs tcp.rs lib.rs macros.rs models/ feistel.rs ipv4_address_ranges.rs ipv4_range.rs mod.rs port_range.rs port_ranges.rs network/ arp.rs datalink.rs dns.rs ethernet.rs http.rs ip.rs mod.rs tcp.rs tls.rs progress.rs ratelimit.rs utils.rs mrpf_http1_scanner/ FCK.md Problem.md src/ Howto.md connections.rs lib.rs models/ http1request.rs http1response.rs mod.rs progress.rs receive.rs scanner.rs template.rs tls.rs transmit.rs mrpf_tcp_syn_scanner/ src/ lib.rs progress.rs receive.rs scanner.rs transmit.rs mrpf_tls_scanner/ src/ how_to_manage_snis.md lib.rs progress.rs receive.rs scanner.rs targets.rs transmit.rs mrpf_whois/ README.md src/ main.rs record.rs service_roles_security.md services/ mrpf_bare_metal_worker/ src/ cap_validation.rs main.rs shutdown/ linux.rs mod.rs other.rs tracing_setup.rs worker.rs mrpf_cert_transparency/ src/ controllers/ certclient.rs database.rs mod.rs lib.rs main.rs models/ logserver.rs logserverentry.rs logserverrange.rs merkletree.rs mod.rs retryconfig.rs utils/ extractcert.rs fqdn.rs mod.rs transformations.rs mrpf_scanner_api/ README.md src/ main.rs messages.rs progress/ https.rs mod.rs tcpsyn.rs tls.rs mrpf_task_worker/ src/ handler.rs lib.rs shared/ lambda_helpers/ src/ db_details.rs lib.rs path_params.rs query_params.rs tracing.rs mrpf_cert_parser/ src/ 
error.rs fqdn.rs lib.rs parser.rs tags.rs tlv.rs mrpf_core/ sql/ initial.md src/ db/ client.rs mod.rs recon/ apn_device.rs domain.rs ipv4address.rs jobs.rs mod.rs notification.rs target.rs words.rs tasks/ data.rs mod.rs tasks.rs error/ base.rs mod.rs lib.rs recon/ apn_device.rs domain.rs ipv4address.rs job.rs job_history.rs mod.rs notification.rs soa_record.rs target.rs template_key.rs word.rs word_category.rs wordlist.rs tasks/ definitions/ conditional.rs dns_scan_tasks/ dns_scan.rs dns_scan_aggregator.rs dns_scan_generator.rs mod.rs filter.rs insert_data.rs notifications/ error.rs generic.rs mod.rs task_result.rs recon_api_tasks/ list_models.rs mod.rs model_type.rs upsert_models.rs tcp_syn_tasks/ mod.rs tcp_syn.rs tcp_syn_aggregator.rs tcp_syn_generator.rs tls_scan_tasks/ mod.rs tls_scan.rs tls_scan_aggregator.rs tls_scan_generator.rs execution_context.rs manager/ handler.rs mod.rs models/ completion_policy.rs messages.rs task.rs task_collection.rs task_container.rs task_definition.rs task_kind.rs task_result.rs task_state.rs worker_requirements.rs queue_client.rs templates/ data_kind.rs mod.rs parser.rs template.rs mrpf_matchers/ src/ lib.rs matchers/ genericmatcher.rs mod.rs rangematcher.rs wordmatcher.rs ## Files ### `.github/copilot-instructions.md` **Absolute path:** `/root/mrpf/.github/copilot-instructions.md` ```markdown # Copilot Instructions for MRPF Recon Tooling This project is a high-performance Rust-based reconnaissance toolkit for bug bounty hunting, optimized for speed, low memory, and custom protocol research. AI agents should follow these codebase-specific conventions and workflows: ## Project Overview - **Language**: Rust (latest stable version, currently 1.90.0 as of 2025-09-14). - **Purpose**: High-performance reconnaissance tooling for bug bounty hunting, emphasizing speed, low memory usage, and custom protocol implementations. - **Target Platforms**: ARM architecture running Linux or macOS. - **Developer**: Single developer, allowing for aggressive refactoring and downtime when beneficial. - **Focus**: Performance-critical code with minimal external dependencies, custom network and protocol stacks, and the ability to bypass RFCs/standards for security research. ## Architecture & Major Components - **Custom Network Stack**: Implements Ethernet/IP/TCP at datalink level using `libpnet`, with custom packet crafting and parsing. See `mrpf_engine/` for core logic. - **Thread Model**: Three main threads per scan: Sending, Receiving, and Status Reporting. No shared state between send/receive; packet tracking uses TCP sequence/ack hashes. See `README.md` for mermaid diagrams and thread flow. - **Scan Types**: TCP SYN, TLS scraping (CN/SAN only), HTTP(S). Each scan type is a module, easily extensible. - **API Layer**: `mrpf_scanner_api/` exposes a WebSocket API for scan orchestration. Message formats and workflows are documented in its `README.md`. - **Models**: Shared types in `mrpf_models/`. ## Developer Workflows - **Build**: Use `cargo build`. For raw socket access: - Linux: `sudo setcap cap_net_raw+ep target/debug/mrpf` - macOS: Run as `sudo`, or use Wireshark's ChmodBPF for BPF device permissions (see `README.md` for details). - **Run**: Avoid `cargo run` for privileged binaries; build, set permissions, then run directly to preserve capabilities. - **Profiling**: Use `cargo flamegraph` or `cargo instruments` (macOS) for performance analysis. See `README.md` for flamegraph and sampling tips. - **Testing**: Unit tests are in each crate, run with `cargo test`. 
Performance benchmarks use `criterion` (see `mrpf_models/benches/`). ## Project-Specific Patterns - **Minimal Dependencies**: Prefer `std` over crates. Only use external crates (e.g., `pnet`, `byteorder`, `serde`) when essential for performance or protocol parsing. - **Memory/Perf**: Avoid heap allocations in hot paths; use stack/arena buffers and zero-copy parsing. Drop unused data ASAP (e.g., discard TLS bytes after extracting CN/SAN). - **Extensibility**: New scan types should follow the thread model and packet tracking conventions. Reference `mrpf_engine/src/` for examples. - **Error Handling**: Use `Result`/`Option` but avoid deep error chains in perf-critical code. Log only actionable errors. - **Cross-Component Communication**: Use message passing (channels) between threads; avoid global state. WebSocket API uses JSON messages (see `mrpf_scanner_api/README.md`). ## Integration & External Systems - **AWS/Fargate**: Designed for ARM Linux/macOS, but can run in containers with raw socket support. See `README.md` for deployment notes and cost analysis. - **Database/Models**: See `mrpf_models/` and `__random_notes/Database model.md` for schema and graph/relational discussions. ## Key MRPF Rust Libraries & Infrastructure - **mrpf_engine/** Core custom network stack and scan logic (Ethernet/IP/TCP, packet crafting/parsing). - **mrpf_scanner_api/** WebSocket API for scan orchestration. Message formats and workflows documented in its `README.md`. - **mrpf_models/** Shared types and database models. - **mrpf_core/** Scan orchestration, scheduling, and distributed task management. - **mrpf_cert_transparency/** Certificate transparency and TLS scraping logic. - **mrpf_api_client/** Client library for interacting with the scanner API. - **infrastructure/** Deployment scripts, container specs, CI/CD, and cloud integration (e.g., AWS/Fargate). - **__random_notes/** Design notes, database ideas, and performance experiments. **When in doubt, prefer performance, minimalism, and extensibility.** ## Coding Guidelines ### 1. Rust Version and Features - Always use the **latest stable Rust version** (e.g., 1.90.0 as of September 14, 2025). - Leverage modern Rust features (e.g., async/await, const generics, GATs) when they improve performance, reduce memory usage, or simplify code without adding complexity. - Avoid deprecated features; refactor code to use newer equivalents if they provide measurable benefits. - Example: Use `std::sync::Arc` for thread-safe sharing, but prefer stack allocation or `Rc` when threading is not required to minimize overhead. ### 2. Performance and Memory Optimization - Prioritize **speed and low memory usage** in all code. - Avoid dynamic allocations (e.g., `Vec`, `String`) unless necessary; prefer fixed-size arrays or buffers when possible. - Example: For TLS scraping, use a precomputed constant for cipher suites instead of dynamic computation. - Use stack-based or arena-based allocations for temporary data to reduce heap fragmentation. - Minimize copies by using references, slices, or zero-copy parsing (e.g., `nom` or manual byte parsing). - Profile and benchmark critical paths using tools like `criterion` or `perf` to identify bottlenecks. - Drop unnecessary data from memory as soon as possible (e.g., discard irrelevant TLS response bytes after extracting CN/SAN). ### 3. Minimal External Crates - Use as few external crates as possible to reduce dependency overhead and maintain control over the codebase. 
- Prefer standard library (`std`) modules for core functionality (e.g., `std::net` for networking, `std::sync` for concurrency). - Only include external crates when they provide significant performance or functionality benefits unavailable in `std`. - Example: Use `byteorder` for endianness handling if manual parsing is too complex, but avoid heavy frameworks like `tokio` or `hyper`. - When using external crates, pin to specific versions and audit for performance and security implications. ### 4. Custom Network and Protocol Stacks - Implement **custom network stacks** (e.g., TCP, HTTP/HTTPS, DNS) to bypass RFCs/standards for security research purposes. - Example: Custom TCP stack with separate send/receive threads and TCP SYN cookies for high-throughput connection tracking, inspired by `masscan`. - Parse only the necessary data from protocol responses to minimize memory and CPU usage. - Example: For TLS scraping, extract only the Common Name (CN) and Subject Alternative Names (SAN) from server responses, discarding all other data. - Avoid adhering to standards when it restricts security research (e.g., allow malformed packets or non-standard protocol behavior). - Use raw sockets or low-level APIs (e.g., `libc::socket`) for maximum control over network operations on Linux/macOS ARM. ### 5. Concurrency and Threading - Use lightweight threading models for high-throughput tasks (e.g., separate send/receive threads in the TCP stack). - Prefer `std::thread` over heavy async runtimes (e.g., `tokio`) unless async provides clear performance benefits for I/O-bound tasks. - Implement thread-safe data sharing only when necessary, using `std::sync` primitives like `Mutex` or `RwLock` sparingly to avoid contention. - Use TCP SYN cookies or similar techniques for efficient connection tracking in high-throughput scenarios. ### 6. Platform Compatibility - Ensure all code runs on **ARM architecture** with **Linux** or **macOS**. - Avoid platform-specific dependencies unless they are supported on both Linux and macOS (e.g., use `std::os::unix` for Unix-specific features). - Test code on both platforms to ensure portability, using conditional compilation (`#[cfg(...)]`) when necessary. - Optimize for ARM-specific features (e.g., NEON instructions) if they improve performance without sacrificing portability. ### 7. Refactoring and Maintenance - Aggressively refactor code when it improves performance, reduces memory usage, or simplifies maintenance. - Accept downtime for refactoring, as the project is maintained by a single developer. - Prioritize clear, maintainable code over premature abstractions; only introduce abstractions if they measurably improve flexibility or performance. - Document performance-critical sections with comments explaining trade-offs (e.g., why a constant cipher suite is used in TLS scraping). ### 8. Security Research Considerations - Design code to **break RFCs/standards** when needed for bug bounty research (e.g., sending non-standard packets or parsing partial responses). - Implement custom protocol stacks to allow flexibility in packet crafting and response handling. - Ensure tools can handle edge cases (e.g., malformed server responses) without crashing. ### 9. Error Handling - Use Rust’s `Result` and `Option` for robust error handling, but avoid excessive error propagation in performance-critical paths. - Implement a custom Error and Result type for each library crate to encapsulate common error scenarios. 
- Log errors minimally to avoid performance overhead; focus on actionable error messages for debugging.
- Gracefully handle network errors (e.g., dropped connections) to ensure tools remain stable during scans.

### 10. Testing and Validation

- Write unit tests for critical components (e.g., TCP stack, protocol parsers) using `#[cfg(test)]`.
- Include performance benchmarks using `criterion` to validate optimizations.
- Test on ARM-based Linux and macOS systems to ensure compatibility and performance.
- Simulate non-standard protocol responses to verify robustness for security research.
```

### `.github/instructions/mrpf_models.instructions.md`

**Absolute path:** `/root/mrpf/.github/instructions/mrpf_models.instructions.md`

```markdown
---
description: MRPF models coding standards
applyTo: "mrpf_models/**/*.rs"
---

- When implementing sqlx serialization, use `sqlx::FromRow` and place it behind the ```sqlx``` feature flag.
```

### `ACTIELIJST.md`

**Absolute path:** `/root/mrpf/ACTIELIJST.md`

```markdown
# Action list for today

Try to update this list every day and remove tasks once they are done. You don't have to keep everything here; it just gives you some focus for the day.

- Rename TaskCollection in macApp to Job and JobHistory.

## Ideas for tomorrow

- Ensure my list models function work
- Ensure a reference to ${target.id} works through SwiftUI
- Implement pagination for notifications in the Apple App (AsyncStream perhaps?)
- Implement killing bare_metal_worker instances if they exceed their time limit.
- Implement DNS authoritative nameserver lookups and pass them on to DNS resolver tasks
- Implement a task manager finalization task. Collect basic statistics, write them to the statistics table and delete all temporary data.

### Deadline

I'm giving myself two months from now to build out my tooling further. On 7 December 2025 I have to stop and try to find at least 1 bug using my own tools.

## What do I need to do before my new MRPF environment is roughly on par with the old one?

- TaskCollection API calls + scheduler
- Tasks for my TLS scanner, TcpSyn scanner and DNS resolver (let's not spend effort rebuilding crt.sh; we should focus on my own transparency records log)
- Update the MRPF Apple app to work with the new API

## Ideas for later

- Rename TaskCollection to ScheduledTaskCollection in code and database, OR to Job?
- Man, it's useless, but I REALLY want to have Live Activities working on my iPhone for any in-progress TaskCollection. That would be so sick.
- Use Iterator generics in my scanners for providing the targets. This way we can generalize the inputs better and provide different flavours of inputs. E.g. sometimes we want to do DNS resolving with the same DNS server, other times we have specific resolvers per target. Same with TLS scanning: sometimes we're fuzzing for a bunch of SNIs, other times we want to find a specific SNI on a range of IPs/ports. Something like Iterator>>
- Ensure the receive thread is stopped when the TLS/TCP SYN scanner is finalised. This will require me to open a channel to the receive thread to tell it to stop. We should have channels to the send and progress threads as well so we can signal other things. I think I thought about this earlier when implementing my WebSocket interface. Things like stop, pause and resume would all be nice to be able to send. However, I would really prefer to have a lot more deduplication in my scanner code first before adding more complexity.
- Revisit engine and scanners to include tracing, introduce a proper error/result type, and perhaps find a better abstraction to deduplicate some scanner code?
- Read this blog to see if we can optimize some stuff using these ideas: . Perhaps move it to my mdBook
- Update the ipv4address merge to allow merging ipv4addresses from different targets. We don't typically store duplicate IP addresses for different targets, but we might want to have tasks that scan IP ranges for different targets. That means we need to treat target_id as part of the uniqueness of an ipv4address entry, which means we also need to update our merging logic to take that into account.
- Update the Task Manager process_tasks function so that, if there is an error, the task_collection is marked as failed and no further tasks are started.
- Remove the functions completely from my templates. Instead we should create separate ETL tasks that can perform actions. Think of lowercase, uppercase, trim, replace, to_snake_case, to_camel_case, etc. Best would be to see if we can have a single Transformation task where we can provide different actions to perform on input data in a single pass.
- Finish my task examples, including how conditionals and loops will work.
- Perhaps move away from JSON for API and WebSocket calls to a binary protocol so I can keep my wordlists and HTTP requests/responses more pristine. Hacking is often possible through misinterpretation of encodings. Grok mentioned [CBOR](https://cbor.io/), which seems to be the most mature option.
- Create a TaskCollectionComplete event to delete everything and collect statistics.
- Implement a timeout for tasks. Remove updated_at and replace it with started_at and completed_at. Think about created_at; preferably set it myself instead of via a trigger, so that all tasks created by a task collection or generator task share the same timestamp. WAIT A MINUTE, maybe don't implement created_at at all because we have UUIDv7? Can we generate UUIDv7 in Rust with the same timestamp if we want to?
- Write a blog about how I now keep notes and tasks for myself with AI. ACTIELIJST.md split into today, tomorrow and later. Brainstorm notes in markdown that I do NOT save but use to quickly order my thoughts. Asking AI inside and outside my code editor to organize my thinking further and to write code. Being rigorous about deleting old notes; if something is important enough, it will resurface. Moving notes that are strong enough into my mdBook as documentation.
- Examine all Error::Other cases in the task manager and see if we can make them more specific.
- Certificate transparency records: looking at DynamoDB, we're at 509,377,316 entries right now. We should take this into account when designing our PostgreSQL schema and indexing strategy.
- Create a table for AWS account IDs. Alessandro has already collected 280 owned AWS account IDs. Could be very interesting for finding things. We also want to include the account IDs we already know from targets.
- Write a blog about the confusion people have between host, IP, domain, root paths, DNS, TLS SNI, etc. Should be called "What exactly is a host?", with the subtitle "Don't let abstractions get in the way of the truth".
- My image builder pipeline is probably all overkill. Just the minimal AWS EC2 image plus the launch template would have been fine. Maybe I should throw away the AMI image builder. The only thing is that it could give me some flexibility later, plus maybe I can pull the mrpf git repo and compile on the image itself; right now cross-compiling takes a veeery long time.
```

### `README.md`

**Absolute path:** `/root/mrpf/README.md`

```markdown
# My Rusty Pension Fund

[![Build Status](https://github.com/thiezn/mrpf/actions/workflows/ci.yml/badge.svg)](https://github.com/thiezn/mrpf/actions/workflows/ci.yml)

> **ℹ️ NOTE:**
> You can find more detailed and updated documentation in the [book](https://mrpf.mortimer.nl).
> The text below is the older `README.md` file, but should still have some useful information, especially around running the project and profiling.

I want to write a flexible Rust library for bug bounty hunting. Its focus is to build a very fast and memory/CPU optimized scanner for various scenarios. The main initial focus is on specialized scanners for:

- TCP SYN port scanning
- TLS certificate Subject Alt Name scraping
- HTTP(S) scanning

The Rust library will implement its own custom network stack using libpnet to be able to optimize everything as much as possible, similar to the masscan service. The network stack should be easily extensible, so new scan types that leverage the custom network stack can be added in the future.

The core concept of the network stack is that it uses libpnet at the datalink layer and constructs my own Ethernet, IP and TCP packets. It runs in two separate threads, one for sending packets and another for receiving packets. Similar to masscan, synchronisation between the send and receive threads is avoided by using the TCP sequence and ack numbers with a hash to keep track of which sent packet corresponds to which received packet.

Please describe a clear initial outline of the files, modules and traits used in the library. In addition, create various mermaid.js diagrams, a high-level overview and some detailed diagrams for the various traits, that illustrate the architecture.

## Folder structure

```text
mrpf/
├── clients
│   ├── mrpf_api_client        # Client to interact with the API
│   └── mrpf_cli               # Command-line interface for the various MRPF services
├── infrastructure
│   ├── db_and_network         # Deploys AWS RDS PostgreSQL and VPC
│   └── mrpf_api               # The AWS Lambda function for the API
│   └── task_manager           # The AWS Lambda functions and queues for the task manager
├── scanners
│   ├── mrpf_engine            # The core engine that drives the scanners. Based on masscan
│   ├── mrpf_dns               # DNS scanner
│   ├── mrpf_http1_scanner     # HTTP/1.1 scanner
│   ├── mrpf_tcp_syn_scanner   # TCP SYN scanner
│   ├── mrpf_tls_scanner       # TLS scanner
│   ├── mrpf_whois             # WHOIS scanner
├── services
│   ├── mrpf_cert_transparency # Certificate transparency record scraper
│   ├── mrpf_scanner_api       # Websocket API for the scanners. Used by the Apple universal app
├── shared
├── mrpf_matchers              # Generic matching library for filtering data
├── mrpf_core                  # The core models and features of the system. Includes the generic database client, the recon models and the task manager models and definitions.
```

## How to use

Install with ```cargo build```

Make sure to either use sudo or set the capabilities of the binary to allow it to send raw packets.

```sudo setcap cap_net_raw+ep target/debug/mrpf```

On macOS do ```sudo target/debug/mrpf```

### Permissions

On Linux you can use .. to give the script access to raw sockets:

```sh
sudo setcap cap_net_raw+ep target/debug/mrpf
```

On macOS you will have to use BPF permissions.
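Whichever platform you are on, it helps to fail fast when the binary lacks raw-socket/BPF access, rather than discovering it mid-scan. Below is a minimal sketch, assuming the `pnet` datalink API already used by `mrpf_engine`; `check_raw_socket_access` is an illustrative helper, not an existing function in the codebase:

```rust
// Illustrative helper (not part of the codebase): try to open a datalink
// channel before starting a scan, so missing permissions produce a clear
// error instead of a scan that silently sends nothing.
use pnet::datalink::{self, Channel, Config};

fn check_raw_socket_access(interface_name: &str) -> Result<(), String> {
    let interface = datalink::interfaces()
        .into_iter()
        .find(|iface| iface.name == interface_name)
        .ok_or_else(|| format!("interface {interface_name} not found"))?;

    match datalink::channel(&interface, Config::default()) {
        Ok(Channel::Ethernet(_tx, _rx)) => Ok(()),
        Ok(_) => Err("unsupported datalink channel type".to_string()),
        // A PermissionDenied error here usually means setcap (Linux) or
        // ChmodBPF/sudo (macOS) has not been applied to this binary.
        Err(e) => Err(format!("cannot open datalink channel: {e}")),
    }
}
```

Calling something like this early in `main` gives an immediate, actionable error message pointing at the permission setup described in this section.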
We can leverage the way Wireshark does it, if it's installed it will do the following ```text Wireshark on macOS uses the `ChmodBPF` script to manage permissions for Berkeley Packet Filters (BPF) devices, allowing users to capture packets without needing `sudo`. Here's how it works: 1. **ChmodBPF Script**: This script changes the ownership and permissions of the BPF devices so that users in the `admin` group can access them. 2. **Installation**: - When you install Wireshark, it typically includes the `ChmodBPF` script and a launch daemon plist file (`org.wireshark.ChmodBPF.plist`). - The script is placed in `/Library/Application Support/Wireshark/ChmodBPF/`. - The plist file is placed in `/Library/LaunchDaemons/`. 3. **Permissions and Ownership**: - The script sets the BPF devices to be owned by `root` and assigns the `admin` group read and write permissions. - This allows members of the `admin` group to use BPF devices without needing `sudo`. 4. **Loading the Daemon**: - The launch daemon ensures that the `ChmodBPF` script runs at startup, applying the necessary permissions to the BPF devices. - You can manually load the daemon with: ```sh sudo launchctl load /Library/LaunchDaemons/org.wireshark.ChmodBPF.plist ``` 5. **Reboot**: After installation, a reboot ensures that the changes take effect. This setup allows Wireshark to capture packets without requiring `sudo` each time[1](https://andrewbaker.ninja/2023/01/14/macbook-fixing-the-wireshark-permissions-bug-you-dont-have-permission-to-capture-on-that-device/)[2](https://ask.wireshark.org/question/16508/chmodbpf-permission-denied-on-macos-10154-catalina/). ``` If you have this in place, all you need to do is change the group of the binaries to admin. If you build it with sudo cargo run before, you first need to remove the full target folder. Then you can change the binaries and do cargo build first (without sudo) ```sh rm -rf target/ cargo build sudo chown root:admin /Users/thiezn/Development/mrpf/target/debug/mrpf sudo chown root:admin /Users/thiezn/Development/mrpf/target/debug/mrpf_scanner_api ls -ltrah target/debug/mrpf* | grep "\-rwx" ``` The problem is that cargo run will rebuild things, causing it to overwrite the binary permissions again. So you need to do cargo build first, then change the permissions and then run the binary instead of doing cargo run :( Which needs sudo, so it's a catch-22. BUT if the engine doesnt require a rebuild, and you only make changes to scanner_api, the rebuild will only happen for scanner_api and it still is able to run without sudo. This is actually strange.. :( ## MacOS profiling MacOS has a profiler called Instruments. Linux has a common one called perf. For MacOS we could perhaps use the following crate: You can install with ```cargo install cargo-instruments``` You can see the templates it has by doing ```cargo instruments --list-templates``` Run for instance the allocations by doing ```cargo instruments --template "Allocations"``` ### Flamegraph profiler Run ```cargo flamegraph``` for a quick flamegraph. to profile dev do ```cargo flamegraph --dev``` Note that there's a big difference between dev and prod. For instance, when doing a dev flamegraph, a lot of memory was spent doing the tcp syn cookie creation. However, when doing the prod flamegraph, the whole method was not even visible. 
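Because a dev-profile flamegraph over-represents unoptimized code like this, a `criterion` benchmark compiled with the release profile is a more reliable way to measure a small hot function. A minimal sketch; the `syn_cookie` function below is a placeholder stand-in, not the real `mrpf_engine` implementation:

```rust
// benches/syn_cookie.rs — registered as a [[bench]] with `harness = false`.
use criterion::{criterion_group, criterion_main, Criterion};
use std::hint::black_box;

// Placeholder stand-in for the real SYN-cookie hash in mrpf_engine.
fn syn_cookie(saddr: u32, daddr: u32, sport: u16, dport: u16, seed: u64) -> u32 {
    let mut h = seed ^ (((saddr as u64) << 32) | daddr as u64);
    h = h.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ (((sport as u64) << 16) | dport as u64);
    ((h >> 32) as u32) ^ (h as u32)
}

fn bench_syn_cookie(c: &mut Criterion) {
    c.bench_function("syn_cookie", |b| {
        b.iter(|| {
            syn_cookie(
                black_box(0xC0A8_0001),
                black_box(0x0A00_0001),
                black_box(40_000),
                black_box(443),
                black_box(0xDEAD_BEEF),
            )
        })
    });
}

criterion_group!(benches, bench_syn_cookie);
criterion_main!(benches);
```

Run it with `cargo bench`, which uses an optimized profile comparable to what the real scans run with.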
run samply which will generate a flamegraph and open it in your browser ```sh sudo samply record target/release/mrpf ``` ## False positives Some environments give a bunch of false positives, for instance scanning the volksbank ip range. Here's nmap explaining some things: ## Change hash to cityhash bindings for google's hash pure rust This is 32bits and supposedly quite fast. We should measure against our own hash. Actually looking at the benchmarks of [cityhasher itself](https://docs.rs/cityhasher/latest/cityhasher/#32-bit-hashing), [crc32fast](https://github.com/srijs/rust-crc32fast) is even faster as it uses SIMD instructions. ``` ### `book/src/AI/README.md` **Absolute path:** `/root/mrpf/book/src/AI/README.md` ```markdown # No AI being used here Apart from helping out with coding, I've not built any specific AI functionality into this project. Obviously there's a lot of potential: - Generate new wordlists based on targets. Initially I just want to use Apple's own models in swift to generate stuff there from the app so i can visually see whats going on, filter before commiting to db. - Extract paths, domains and parameters from javascript and html files. Should train my own model to output this into a structured format. This would be really great to have and good practice for building my own fine-tuned models. ``` ### `book/src/Database model IDEAS TO BE CATEGORIZED INTO mdBook.md` **Absolute path:** `/root/mrpf/book/src/Database model IDEAS TO BE CATEGORIZED INTO mdBook.md` ```markdown # Database model ONE THING THAT REALLY HELPS HERE IS TO THINK ABOUT FIRST PRINCIPLES!! I'VE BEEN MODELLING MY API AND MODELS A BIT TOO MUCH WITH HIGH LEVEL CONCEPTS IN MIND. THIS HAS LEAD TO SOME THINGS THAT DON'T 100% MATCH. Some examples of the old mpf - Domain objects are actually more resembling my new WebApp object. - Top Level Domains are actually zones. - WebApps are bound to one or more services - A service is a combination of an IP address, port, port type and protocol This document outlines the database model for the MRPF API. Did a bunch of querying and came up with I think a nice model for storage of my data. Note that I have asked to not include foreign keys as Aurora DSQL doesn't support them. I DID ask it to add triggers though as the create/update timestamp is just too good to pass up. Hopefully AWS will quickly introduce triggers in AWS DSQL. The only thing I'm not really happy with still is the WebApp part and its relation to fqdn/ip/port/ip_type ## Tree or graph, or relational? The model I've come up with is okish and tackles some of the earlier issues BUT I still can't really create the proper visualization that I would want to implement. Most of the updates are just against simple ip addresses or domains, but the data in itself resembles more of a tree or a graph. I need to decide how I want to represent it. Weblinks are more like a graph, as a website can link to a website that links back to it. A graph can be more complicated as it introduces potential circular dependencies. A tree avoids this and when drawing out my mermaid diagram, it does resemble a tree more than a graph. HMM But does this mean every child can only have 1 parent? yes thats what a tree means, you also have a Directed Acyclic Graph (DAG) which can have multiple parents but no cyclic relations. A DAG with the ltree postgresql module The DAG still doesn't allow me to express everything properly i think. A hierarchy for a webapp all the way through a CDN, origin and internal host should be linked to the web app. 
Not to something else? Perhaps I should use the schema I have below with the static relationships for fqdns to ip addresses etc. But then in the webapp have a list (or json blob to allow different types in the list?) to show the relation? ```rust struct WebApp { id: i32, path: String, target_id: i32, relationships: Vec, } Enum Relationship { Fqdn(Fqdn), IpAddress(IpAddress), Service(Service), } ``` The problem is how do we keep the relationship up to date? Will we manually craft them? Also, it it actually a tree as it can be associated with multiple ip addresses and domains, and fan out and in again. Ok Copilot calls this a Rooted Directed Acyclic Graph (RDAG). This is a DAG with a single root node parent. In this case the WebApp will be the single root node? I think wiki calls is a rooted graph. BUT it should be a Rooted Directed Graph as all relations go downward. Rooted directed graphs (also known as flow graphs) or . Pfff, this graph stuff is interesting. I actually think there are different graph types that you can overlay onto my model. - if you highlight a FQDN you want it to show: - All directly connected fqdns - All directly connected webapps - if you highlight a webapp you want it to show: - All directly connected fqdns - All directly connected ip addresses - All directly connected services - all related origin fqdns/ips - all related internal ips/fqdns ## TODO's ### Found on list FQDNs should have a list of IP/Port combinations where it was found on (or the ip/port combination should link to the fqdns). This way I can see for instance with TLS scaping where hosts where found, even if there's no DNS record for it. ### Services UPDATE: Ok, One part I'm missing is something I could call a service. A service consist of a ip address + port + port type (tcp, udp or both) + protocol. This is what I need to connect to a webapp. The web app should have two additional columns, an array of Host and array of SNI, OR should this be a composite of the Host and SNI? To be decided if i want to create a separate webapp entry for each host/sni combination or not. The benefit is a more flat structure, the downside is that I need to create a new webapp entry for each host/sni combination. I think I'll go with the flat structure for now. ### Zones I think I should go back to the idea of having a zone table. This will make storage more efficient on the database side I think. I remember thinking about this model already when I was investigating how I could store certificate transparency records. Having the separate zone names also might help coming up with iterations for wordlists of FQDNS. Think of queries that can count how many times www appears in the zones belonging to a target. Or sorting on the most common zone names so I can mix and match them. Thinking about the MPF iOS/macOS app, it would also very much be beneficial if we can create a tree view of the zones, similar to burp suite. At the moment I have the big domain table but thats a bit trickier from a querying performance and caching perspective. If I have a collapsed tree view, I wouldn't need to query all data every time, I can just query the top level zones, and only when expanding, retrieve the sub zones. If I combine this with a caching mechanism AND my idea of a better REST API that can qujery for the changes since a specific timestamp, it can become very flexible and scalable. If the changes since timestamp also includes a list of zone/fqdn IDs that are deleted since the timestamp, I can also remove them from the tree view. 
I can still have a button that just refreshes everything as well to clear up inconsistencies. I'm not sure what the performance of a recursive lookup is in a table though, something to read up on. I also need to think about the storage in this case. Lets say I have two fqdns belonging to different targets: app.www.target1.com web.www.target2.com target1.com and target2.com will be separate zones without a parent (Makes no sense to store top level domains I recon?) How should I store the www zone in this case? Would I have a single www zone with multiple parents (target1.com and target2.com). Or would i create separate www zones for each target? I think it will be better to have separate www zones to be honest as I would want to couple fqdn's to a specific zone ID. I would also want to store the NS records and SOA records for that particular zone (IF it has any, else we should rely on the parent NS records instead.) UPDATE: having the zones and their NS servers will help make our DNS resolving more robust. We can directly query the authoritive DNS servers for the most recent udpates. #### DNS Zone lookups What would also be nice is if we can have NS lookup chains visualised somehow for domains. How would this be visualized in my app? rabobank.com -> NS 10.0.0.1, 10.0.0.2 -> aws.rabobank.com -> NS 192.168.0.1 -> cloud.aws.rabobank.com -> NS aws-123.ns.aws.com We will traverse the FQDN hierarchy until we don't find any more different NS records. Lets write it out again using our AWS setup *rabo.cloud* a20-66.akam.net a18-65.akam.net a3-67.akam.net a7-64.akam.net a1-191.akam.net a12-64.akam.net. *aws.rabo.cloud* ns-819.awsdns-38.net. ns-1552.awsdns-02.co.uk. ns-21.awsdns-02.com. ns-1267.awsdns-30.org. *azure.rabo.cloud* ns3-06.azure-dns.org. ns4-06.azure-dns.info. ns2-06.azure-dns.net. ns1-06.azure-dns.com. *sandbox.aws.rabo.cloud* ns-671.awsdns-19.net. ns-457.awsdns-57.com. ns-1937.awsdns-50.co.uk. ns-1362.awsdns-42.org. The way I could scan for this is traverse the domain name dots and query for SOA records. The SOA record will give a hint on the hierarchy, and then I could resolve the NS records that are part of it. See these dig commands for example: ```sh > dig blabla.sandbox.aws.rabo.cloud AUTHORITY SECTION: sandbox.aws.rabo.cloud. 900 IN SOA ns-1937.awsdns-50.co.uk. > dig blabla.blabla.aws.rabo.cloud AUTHORITY SECTION: aws.rabo.cloud. 900 IN SOA ns-1552.awsdns-02.co.uk. > dig blabla.blabla.blabla.rabo.cloud AUTHORITY SECTION: rabo.cloud. 600 IN SOA a1-191.akam.net. ``` ## Latest Reply from Copilot Got it! Let's update the `webapp` table to use a vector of Service IDs instead of the `fqdn_tcp` and `ipv4_tcp` composites. We'll also remove the `host_sni` composite and add `hostname` and `sni` columns that can both be null. ## Wordlists Some wordlists could perhaps benefit from being sorted by hit ratio and/or criticality. Should I normalize the words in a wordlist and assign hit rate and criticality to them? I can then have the same words referenced in different wordlists, and sort a wordlist by hit rate or criticality. Also would have to think about hit rate across targets. That is very likely to be different so how would I treat that? Some kind of base hit rate and a mutating factor of hit rate per target? Or a hit rate per target and then generate the average hit rate across all targets? ### Updated Table Definitions Reply from Copilot: Got it! Let's update the `webapp` table to use a vector of Service IDs instead of the `fqdn_tcp` and `ipv4_tcp` composites. 
We'll also remove the `host_sni` composite and add `hostname` and `sni` columns that can both be null. 1. **targets** ```sql CREATE TABLE targets ( id SERIAL PRIMARY KEY, name VARCHAR(255) NOT NULL, description TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` 2. **fqdns** ```sql CREATE TABLE fqdns ( name VARCHAR(255) PRIMARY KEY, target_id INTEGER, cnames TEXT[], -- Array of CNAMEs ip_addresses inet[], -- Array of IP address IDs reversed_name VARCHAR(255), -- Column for reversed FQDN for suffix searches created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` 3. **ipv4_addresses** ```sql CREATE TABLE ipv4_addresses ( address inet PRIMARY KEY, -- Using inet to store IPv4 address type VARCHAR(50) CHECK (type IN ('generic', 'private_ip', 'akamai', 'azure', 'aws', 'gcp', 'cloudfront')), tcp_ports INTEGER[], technology_ids INTEGER[], -- Array of technology IDs created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` 4. **technologies** ```sql CREATE TABLE technologies ( id SERIAL PRIMARY KEY, name VARCHAR(255) NOT NULL, version VARCHAR(50), description TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` 5. **networks** ```sql CREATE TABLE networks ( id SERIAL PRIMARY KEY, cidr inet NOT NULL, type VARCHAR(50) CHECK (type IN ('generic', 'private_ip', 'akamai', 'azure', 'aws', 'gcp', 'cloudfront')), created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` 6. **services** ```sql CREATE TABLE services ( id SERIAL PRIMARY KEY, ipv4_address inet NOT NULL, -- Using inet to store IPv4 address port INTEGER NOT NULL, port_type VARCHAR(10) CHECK (port_type IN ('tcp', 'udp', 'both')), protocol VARCHAR(50), created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` 7. **webapps** ```sql CREATE TABLE webapps ( id SERIAL PRIMARY KEY, root_path VARCHAR(255) NOT NULL, target_id INTEGER, service_ids INTEGER[], -- Array of Service IDs technology_ids INTEGER[], -- Array of technology IDs hostname VARCHAR(255), -- Hostname, can be null sni VARCHAR(255), -- SNI, can be null created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); ``` ### Trigger Functions and Triggers 1. **Trigger Function for Timestamps** ```sql CREATE OR REPLACE FUNCTION update_timestamps() RETURNS TRIGGER AS $$ BEGIN NEW.updated_at = CURRENT_TIMESTAMP; IF TG_OP = 'INSERT' THEN NEW.created_at = CURRENT_TIMESTAMP; END IF; RETURN NEW; END; $$ LANGUAGE plpgsql; ``` 2. **Trigger Function for Reversed Name** ```sql CREATE OR REPLACE FUNCTION update_reversed_name() RETURNS TRIGGER AS $$ BEGIN NEW.reversed_name := reverse(NEW.name); RETURN NEW; END; $$ LANGUAGE plpgsql; ``` 3. **Trigger Function for Unique and Sorted Arrays** ```sql CREATE OR REPLACE FUNCTION sort_and_unique_array() RETURNS TRIGGER AS $$ BEGIN NEW.ip_addresses := ARRAY(SELECT DISTINCT UNNEST(NEW.ip_addresses) ORDER BY 1); RETURN NEW; END; $$ LANGUAGE plpgsql; ``` 4. 
**Trigger Function to Update IP Address Type** ```sql CREATE OR REPLACE FUNCTION update_ip_address_type() RETURNS TRIGGER AS $$ DECLARE network_type VARCHAR(50); BEGIN SELECT type INTO network_type FROM networks WHERE NEW.address << cidr ORDER BY masklen(cidr) DESC LIMIT 1; IF network_type IS NOT NULL THEN NEW.type := network_type; ELSE NEW.type := 'generic'; END IF; RETURN NEW; END; $$ LANGUAGE plpgsql; ``` 5. **Trigger Function to Update All IP Address Types** ```sql CREATE OR REPLACE FUNCTION update_all_ip_address_types() RETURNS TRIGGER AS $$ BEGIN UPDATE ipv4_addresses SET type = COALESCE( (SELECT type FROM networks WHERE ipv4_addresses.address << networks.cidr ORDER BY masklen(networks.cidr) DESC LIMIT 1), 'generic' ); RETURN NULL; END; $$ LANGUAGE plpgsql; ``` 6. **Triggers for Each Table** ```sql CREATE TRIGGER trg_update_targets_timestamps BEFORE INSERT OR UPDATE ON targets FOR EACH ROW EXECUTE FUNCTION update_timestamps(); CREATE TRIGGER trg_update_fqdns_timestamps BEFORE INSERT OR UPDATE ON fqdns FOR EACH ROW EXECUTE FUNCTION update_timestamps(); CREATE TRIGGER trg_update_fqdns_reversed_name BEFORE INSERT OR UPDATE ON fqdns FOR EACH ROW EXECUTE FUNCTION update_reversed_name(); CREATE TRIGGER trg_sort_and_unique_ip_addresses BEFORE INSERT OR UPDATE ON fqdns FOR EACH ROW EXECUTE FUNCTION sort_and_unique_array(); CREATE TRIGGER trg_update_ipv4_addresses_timestamps BEFORE INSERT OR UPDATE ON ipv4_addresses FOR EACH ROW EXECUTE FUNCTION update_timestamps(); CREATE TRIGGER trg_update_ip_address_type BEFORE INSERT OR UPDATE ON ipv4_addresses FOR EACH ROW EXECUTE FUNCTION update_ip_address_type(); CREATE TRIGGER trg_update_all_ip_address_types AFTER INSERT OR UPDATE ON networks FOR EACH ROW EXECUTE FUNCTION update_all_ip_address_types(); CREATE TRIGGER trg_update_webapps_timestamps BEFORE INSERT OR UPDATE ON webapps FOR EACH ROW EXECUTE FUNCTION update_timestamps(); ``` ### Index for Suffix Searches Create a B-Tree index on the `reversed_name` column for efficient suffix searches: ```sql CREATE INDEX idx_fqdns_reversed_name ON fqdns(reversed_name); ``` ### Example Data Insertion ```sql INSERT INTO fqdns (name, target_id, cnames, ip_addresses) VALUES ('www.rabobank.com', 1, '{}', '{}'), ('api.rabobank.com', 1, '{}', '{}'), ('mail.rabobank.com', 1, '{}', '{}'), ('www.example.com', 2, '{}', '{}'); INSERT INTO ipv4_addresses (address, type, tcp_ports, technology_ids) VALUES ('192.168.1.1', 'generic', '{80, 443}', '{1, 2}'), ('192.168.1.2', 'generic', '{80, 443}', '{3, 4}'); INSERT INTO technologies (name, version, description) VALUES ('nginx', '1.18.0', 'High performance web server'), ('docker', '20.10.7', 'Containerization platform'), ('apache', '2.4.46', 'Web server software'), ('kubernetes', '1.21.0', 'Container orchestration platform'); INSERT INTO networks (cidr, type) VALUES ('192.168.1.0/24', 'private_ip'), ('10.0.0.0/8', 'private_ip'), ('172.16.0.0/12', 'private_ip'), ('203.0.113.0/24', 'akamai'); INSERT INTO services (ipv4_address, port, port_type, protocol) VALUES ('192.168.1.1', 80, 'tcp', 'http'), ('192.168.1.1', 443, 'tcp', 'https'), ('192.168.1.2', 80, 'tcp', 'http'), ('192.168.1.2', 443, 'tcp', 'https'); INSERT INTO webapps (root_path, target_id, service_ids, technology_ids, hostname, sni) VALUES ('/app1', 1, '{1, 2}', '{1, 2}', 'www.rabobank.com', 'sni.rabobank.com'), ('/app2', 1, '{3, 4}', '{3, 4}', 'mail.rabobank.com', 'sni.mail.rabobank.com'); ``` ### Example: Adding New IPv4 Address and Automatically Updating Type To add a new IPv4 address and illustrate that the `type` 
is filled in automatically: ```sql INSERT INTO ipv4_addresses (address, tcp_ports, technology_ids) VALUES ('192.168.1.3', '{80, 443}', '{1, 2}'); ``` After this insertion, the `type` field for the IP address `192.168.1.3` ``` ### `book/src/README.md` **Absolute path:** `/root/mrpf/book/src/README.md` ```markdown # Introduction My Rusty Pension Fund is my attempt at writing tools for my bug bounty hobby. I have scanners, fuzzers, wordlist generators, task managers, iOS and macOS apps, the WORKS! Although most of it doesn't really work, neither have I ever found a real bug with any of these tools. Still, all good fun to build and learn stuff! Its focus is to build a very fast and memory/cpu optimized scanner for various scenarios. The main initial focus are specialized scanners for: - TCP Syn port scanning - TLS certificate Subject Alt Name scraper - HTTP(S) scanning The Rust library will implement it's own custom network stack using libpnet to be able to optimize everything as much as possible, similar to the masscan service. The network stack should be easily extensible, adding new scan types in the future that leverage the custom network stack. The core concept of the network stack is that it uses libpnet in datalink layer and contruct my own ethernet, ip and tcp packets. It will be running in two separate threads, one for sending packets and another for receiving packets. Similar to masscan, synchronisation of the send and receive thread is avoided by using the tcp sequence and ack numbers with a hash to keep track of what sending packet corresponds to what received packet. ## How do the libraries work together Main principle is: All MRPF tools like scanning, filtering, fuzzing logic should be in external libs (eg. tcp_syn_scanner, cert_transparency, matchers, http1_scanner, etc). This will allow us to build different 'front ends' for these tools. ### MRPF Jobs - The mrpf_core::tasks is the only place that defines Tasks. - Task pull in all external MRPF tool libs in and executes the them in the execute function of the relevant tasks. - The worker nodes are very simple. They pull in the mrpf_core::tasks library and run the tasks. With this architecture, it should be possible to have a single lambda that can execute any arbitrary task. This would make the AWS specific code very small, allowing us to move to other workers without any real effort. TBD if we would separate the MRPF API for scheduling jobs as a separate api or integrate it with the MRPF API. A compromise could be to build a separate API, but do include endpoints for both MRPF API and MRPF Jobs into the MRPF API Client. ### WebSocket interface for iOS/macOS frontend The mrpf_scanner_api provides a WebSocket interface for real-time communication between iOS and macOS applications. It allows you to schedule tools and receive updates on their progress. ### MRPF API This is an API that stored all our recon data. ### MRPF API Client This is an async client for interacting with the MRPF API. ### CLI The mrpf_cli is a command-line interface for interacting with the MRPF framework. It allows users to initiate scans, manage tasks, and retrieve results directly from the terminal. At the moment it hasn't really been implemented yet. ## How will we build REST API's? I want it to be reasonably easy to move away from AWS for my REST API's. I'd also want to try and save costs as API Gateway can get expensive. You could run a Rust API fully inside a single lambda. 
There are two well known paths for this: - Use the official lambda_http crate (ALB / API Gateway / Function URLs) - Run a normal Axum/Hyper server inside Lambda via the AWS Lambda Web Adapter (eg. check this blog and AWS reInvent own slides here: ) ## Remember 'The Algorithm' Elon musk's algorithm, should be applied to everything, in this specific order: 1. Question every requirement What things do I REALLY need? - Continuously find attack surface - Quickly fuzz endpoints with payloads - Detect anomalies in responses What things don't I really need but am I often looking for? - Store all the results - Have a fancy UI 2. Delete any part of the process you can The most common thing that an engineer does is optimizing a thing that shouldn't exist. 3. simplify and optimize 4. Accelerate cycle time 5. Automate ``` ### `book/src/SUMMARY.md` **Absolute path:** `/root/mrpf/book/src/SUMMARY.md` ```markdown # Summary [Introduction](README.md) - [Core Principles](core_principles.md) - [System Architecture](system_architecture/README.md) - [The MRPF API](api/README.md) - [Network Engine](network_engine/README.md) - [Infrastructure](network_engine/infrastructure.md) - [TLS Scraper](network_engine/tls_scraper.md) - [HTTP/1.1 scanner](network_engine/http11_scanner.md) - [Tcp Syn Scanner](network_engine/tcp_syn_scanner.md) - [DNS Resolver](network_engine/dns_resolver.md) - [Whois Resolver](network_engine/whois_resolver.md) - [Ideas for Improvement](network_engine/ideas/README.md) - [Refactor network stack](network_engine/ideas/refactor_network_stack.md) - [MY FIRST ORIGINAL IDEA FOR BETTER THROUGHPUT](network_engine/ideas/MY_FIRST_ORIGINAL_IDEA_FOR_BETTER_THROUGHPUT.md) - [Reducing pselect calls](network_engine/ideas/recucing_pselect_calls.md) - [TLS parse improvements](network_engine/ideas/tls_parse_improvements.md) - [Task Manager](task_manager/README.md) - [Infrastructure](task_manager/infrastructure.md) - [Tasks and Containers](task_manager/tasks_and_containers.md) - [Data](task_manager/data.md) - [Workers](task_manager/workers.md) - [Task Generation and Aggregation](task_manager/task_generators.md) - [Data Templates](task_manager/templates.md) - [Examples](task_manager/examples.md) - [Apple macOS/iOS Universal App](apple_universal_app/README.md) - [MRPF Scanner API](apple_universal_app/mrpf_scanner_api.md) - [Certificate Transparency Records](cert_transparency_records/README.md) - [TODO](cert_transparency_records/TODO.md) - [Coding Guidelines](coding_guidelines/README.md) - [Error handling](coding_guidelines/error_handling.md) - [Security Ideas](security_ideas/README.md) - [AI Wordlists](security_ideas/ai_wordlists.md) - [Protections implemented by AWS for scanning](security_ideas/aws_protections.md) - [Random ideas](security_ideas/random_ideas.md) - [Tricks using DNS lookups for S3 recon](security_ideas/s3_dns_recon.md) - [MPF - Previous iteration](mpf_previous_iteration/README.md) - [What didn't work well](mpf_previous_iteration/what_works_and_what_not.md) - [Contributing](contributing.md) ``` ### `book/src/api/README.md` **Absolute path:** `/root/mrpf/book/src/api/README.md` ```markdown # The MRPF API The MRPF API allows clients to programmatically interact with the MRPF platform. It provides endpoints managing recon data like targets, domains and wordlists, as well as triggering tasks. ## Current State At the moment the code is still running on MPF Python codebase, with a DocumentDB backend. 
I would love to get this into rust for better performance and alreaady have some of the models defined in mrpf_models. Some things I want to work on: - Revisit the templating engine for rust. Think about how to represent things, our wordlist probably need to work better with bytes and then have methods to change things to utf-8/16/etc where applicable - Move away from DocumentDB to PostgreSQL. This will give me back the triggers for timestamps that I very much like. Also, DynamoDB for at least transparency records was just to costly so lets get back to the drawing board ## Ideas and Future Work ### Had some insights? For my MRPF API, I think I might be too quickly trying to push everything in full predefined structs. However, when reading and writing data, I often only want to have a subset: - list all active fqdns of a target id hmm, is it true? Is this the only real example I've found? Ok, lets think about the write queries: - tcp syn scan needs to append ports to an existing ip address - SNI scanner needs to create new fqdn objects and services (ip/port that the sni was found on) - Http scanner needs to update the WebApp content hash - CrtSh needs to create fqdn objects - DNS resolver needs to update fqdn objects, create new ones found through PTR, or update zones with their NS and SOA records All these things can be done with my current task/job manager BUT are these actually not better to run continuously? Scans with larger amount of data can better bypass rate limits due to more randomization. Easier to alert when a new domain has been found? ``` ### `book/src/apple_universal_app/README.md` **Absolute path:** `/root/mrpf/book/src/apple_universal_app/README.md` ```markdown # The Apple universal iOS/macOS app To make it easier to work with my recon data and task scheduler I've created a universal iOS/macOS app in Swift. It provides a nice frontend for all my tools and data. The alternative was to build some kind of web frontend, but truth be told, I just don't enjoy writing Javascript. Swift, and especially SwiftUI, feels a lot more fun and rewarding to build things with and can have a lot more focussed user experience on phones. The current iteration works with the older MPF API and task manager, but I want to move this to the new MRPF API and task manager once I've built that out a bit more. ## HTTP Repeater/scanner The MRPF Scanner API provides a websockets interface into the various scanners built on top of the MRPF network engine. The current macOS app is able to interact with it and I'm trying to work towards a similar functionality as the repeater in Caido and Burp. Instead of making single requests, I've built the templating engine into it, similar to how I'm constructing the task manager. This allows you to more easily fuzz things, I guess it's more akin to the intruder in Burp. I feel I should be able to find a better balance than Caido and Burp for the UI, and feel a mixture between requests and the intruder tab is getting me amost there. I need to iterate more. Other things that would really help with it is a more mature wordlist generators from the app. The killer feature will be my ability to bring together the task manager for scanning, all the collected data and the repeater/inspector in one app. ## Whats the current status? The iOS/macOS app needs work, this would be really nice to give a big refactor but I want to leverage the latest macOS 26 version. This also introduces copilot directly in XCode so should help me learn Swift and best practices a lot faster. 
The time to be a solo developer is now, finally I'm able to build everything myself if I just manage to keep focus on the things I really want to move forward.. - Better handling of textarea in my ‘burp’ mimicking feature - Revisit the job template composition. There’s a bunch of inefficient strange code, which I think I should be able to make more ergonomic in the swift language. All those casting, generics and codable stuff is a mess - Fully buy into the two column Split View and make the macOS design aligned with liquid glass. Althernative might be to switch completely to a Tabview design. Apparently on iPadOS this tab view now transforms in a sidebar automatically, not sure if this carries over to macOS as well? - macOS works ok-ish but iOS is lagging behind. What do I want to do, it probably needs a few different design patterns to work well on the platform. Some actions might just not be suited for a phone either. - Make a more robust wordlist section. Especially Apple's easy integration with language model can be very helpful here to generate new wordlists on the fly. I also need to dig into the wordlist problem a lot deeper and try to take it up to a more professional level. I need to be able to support different encodings, rate words by potential impact, link things across targets, think about efficient storage and retrieval in the database, full integration with the templating engine, etc. ``` ### `book/src/apple_universal_app/mrpf_scanner_api.md` **Absolute path:** `/root/mrpf/book/src/apple_universal_app/mrpf_scanner_api.md` ```markdown # MRPF Scanner API This is a websockets interface to the different scanners built on top of the MRPF network engine. It can be run on any machine that has Rust, at the moment focussed on running it on my mabook itself, but I can see it being useful to run on a VM in the cloud. It would be good if we can somehow get the VM being part of the workers of our task manager as these things have some overlap. The MRPF Scanner API should be only a frontend for the scanners, not do any scanning itself to keep separation of concerns. Running websockets on any server is great as we can have bare-metal workers this way (or for instance my macbook). What would also be nice though is to actually use AWS Lambda for certain tasks here as well. We could leverage AWS WebSocket API Gateway with a lambda backing. Apparently all the keepalive stuff is handled by API gateway, you only pay for real messages and the lambda execution time. ``` ### `book/src/cert_transparency_records/README.md` **Absolute path:** `/root/mrpf/book/src/cert_transparency_records/README.md` ```markdown # Certificate Transparency Records The Certificate Transparency records produced by the big certificate issuers are a goldmine for finding new domains. The most popular way to retrieve this is through crt.sh website or better yet through their PostgreSQL database. However, they have stricter rate limits and more difficulty getting all the records returned for larger subdomains. We can do this better so I've written my own code that can scrape the certificate transparency logs directly from the issuers. However, the main problem eventually is with costs. I wanted to use AWS DynamoDB for this, and although I got it to work and learned a lot by how to model things there, it turns out it's quite costly for this usecase. I am better off moving this to PostgreSQL. 
Also, the Lambda invocation costs are quite high, so it makes more sense to run the initial scraping of all older logs on EC2, my VPS or my macbook. Once the initial bulk is done, we could probably use Lambda to keep the incremental updates going.

## Ideas and Future Work

...
```

### `book/src/cert_transparency_records/TODO.md`

**Absolute path:** `/root/mrpf/book/src/cert_transparency_records/TODO.md`

```markdown
# TODO

I'm afraid I need to rework my approach again. I figured out that the initial connection timeout is a very important way to get rid of a lot of rate limits. Because of this, my approach of sending ranges in batches of 1 per log server isn't holding up great. This is causing the loop to wait for the full connection timeout before it processes a new batch. Instead I should:

- Create a proper RateLimit class, similar to the interface
- Tweak rate limits and connection timeouts to find the optimal balance per server.
- Provide the scan log servers with a HashMap
- Once a range for a particular log server is completed, pluck another one from the HashMap
- If a range fails to be completed in the rate limit timeframe, move to the next range and leave the range as Pending in the database.
- See if we can store the log server's optimal rate limits and preferred range sizes in the database. We could base the latter on the average entry count that a log server returns, so servers with 1024 entries per request can use larger range counts than servers that return a lower count like 32. Something like LOGSERVER# SK: PROPERTIES.

```rust
use std::ops::Range;
use std::time::Duration;

struct RetryConfig {
    max_retries: usize,          // default: 3
    initial_backoff: Duration,   // default: 3s
    max_backoff: Range<u64>,     // e.g. 20..22, a range to randomize within
    exponential_backoff: bool,   // default: true
    step: Range<u64>,            // e.g. 1..2; with exponential backoff we increase this exponentially, otherwise linearly. A range so we can randomize.
}

struct LogServer {
    retry_config: RetryConfig,
    average_entry_size: u16,     // default: 1024
    url: String,
    mirror: Option<String>,
}

struct CertTransparencyClient {
    log_server: LogServer,
    retry_config: RetryConfig,
    max_range_failures: usize,   // maximum number of times a range can fail before we stop trying it completely (default: 3)
}

impl CertTransparencyClient {
    async fn get_entries(&self, range: Range<u64>) {
        // Retrieve entries:
        // - if entries == 0, retry
        // - if entries != 0, reset max_retries
        // - if MaxRangeRetries, return an error; the user of the cert client should skip this range and try another range
        // - if MaxLogServerRetries, return an error; the user of the cert client should abandon this log server completely
        // - if the range completed, return all entries that were retrieved
        //
        // OR: do we want to handle the name extraction here already? Maybe the code above should be the
        // function get_names, and get_entries should return RateLimited and connection errors. Then this
        // function parses those errors and either retries, raises MaxLogServerRetries or MaxRangeRetries,
        // or returns the entries.
    }
}
```
```

### `book/src/coding_guidelines/README.md`

**Absolute path:** `/root/mrpf/book/src/coding_guidelines/README.md`

```markdown
# Coding Guidelines

See .github/copilot-instructions.md for the full coding guidelines and project overview.

## Faster macOS builds

You can get faster builds on mac by disabling XProtect on the terminal app.
[read this](https://nnethercote.github.io/2025/09/04/faster-rust-builds-on-mac.html)

Here are some better docs on how to add terminal to developer tools:
```

### `book/src/coding_guidelines/error_handling.md`

**Absolute path:** `/root/mrpf/book/src/coding_guidelines/error_handling.md`

```markdown
# Error Handling

My prompt:

I am writing a large rust codebase with several libraries and binaries. The code will only be used by myself so doesn't have to take external users into account at all. I want to have a proper way of handling errors without having to write too much boilerplate code and matches everywhere. What are best practices for error handling in my scenario and what do experts advise on this? I would prefer not to bring in external crates. Perhaps Jon Gjengset has good advice, or others?

- Should I use ? as much as possible
- What should my typical return type be? Result<T, E>
- How to handle/catch Errors defined by other external crates and capture them easily upstream without too much hassle
- Should I define my own error types to make this easier?
- Is there a way to avoid using Box<dyn Error> all over the place?

Think hard about maintainability for my requirements and being idiomatic rust

## The 80/20 pattern (no external crates)

I suspect this 80/20 line comes from [this blog](https://masteringbackend.com/posts/rust-error-handling-80-20-guide). The Rust by Example book has [a chapter](https://doc.rust-lang.org/rust-by-example/error/multiple_error_types/wrap_error.html) that looks a lot like this approach.

### Per library crate

1) Define one **crate-local error enum** and a `Result` alias.
2) Implement `Display`, `std::error::Error`, and a few `From<_>` conversions for external error types you actually use.
3) Lean on `?` everywhere; it auto-converts via `From`. [1](https://doc.rust-lang.org/reference/expressions/operator-expr.html)[2](https://doc.rust-lang.org/std/error/trait.Error.html)

### Per binary crate

Return `Result<(), Box<dyn Error>>` from `main` (or use your own top-level error). This is the **only** place you need a `Box<dyn Error>`, keeping it out of the rest of your code. Rust's `Termination` trait makes `main` returning a `Result` work as expected (nonzero exit when `Err`). [3](https://doc.rust-lang.org/std/process/trait.Termination.html)

### Minimal template

```rust
// lib/src/error.rs
use std::{error::Error as StdError, fmt};

#[derive(Debug)]
pub enum Error {
    Io(std::io::Error),
    Json(serde_json::Error),
    // Add more *typed* variants as needed.
    // Optionally a catch-all if you want (kept rare):
    Other(String),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::Io(e) => write!(f, "I/O error: {e}"),
            Error::Json(e) => write!(f, "JSON error: {e}"),
            Error::Other(msg) => write!(f, "{msg}"),
        }
    }
}

impl StdError for Error {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Error::Io(e) => Some(e),
            Error::Json(e) => Some(e),
            Error::Other(_) => None,
        }
    }
}

// Focused conversions so `?` just works:
impl From<std::io::Error> for Error {
    fn from(e: std::io::Error) -> Self { Error::Io(e) }
}
impl From<serde_json::Error> for Error {
    fn from(e: serde_json::Error) -> Self { Error::Json(e) }
}

// Your crate-wide Result alias
pub type Result<T> = std::result::Result<T, Error>;
```

```rust
// lib/src/lib.rs
mod error;
pub use error::{Error, Result};

use std::fs;

pub fn load_config(path: &str) -> Result<serde_json::Value> {
    let data = fs::read_to_string(path)?;  // auto-converts io::Error -> Error via From
    let v = serde_json::from_str(&data)?;  // auto-converts serde_json::Error -> Error via From
    Ok(v)
}
```

```rust
// bin/src/main.rs
use mylib::load_config;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let cfg = load_config("config.json")?; // prints a neat chain on failure
    println!("{cfg}");
    Ok(())
}
```

- Why this works well:
  - `?` is the standard way to **propagate** errors; it's concise and idiomatic. [1](https://doc.rust-lang.org/reference/expressions/operator-expr.html)[4](https://doc.rust-lang.org/book/ch09-00-error-handling.html)
  - Implementing `std::error::Error` + `Display` lets errors **compose** and present well. [2](https://doc.rust-lang.org/std/error/trait.Error.html)
  - Keeping `Box<dyn Error>` **only at the binary boundary** avoids “Box everywhere” while still letting `main` fail gracefully with a proper exit code via `Termination`. [3](https://doc.rust-lang.org/std/process/trait.Termination.html)

---

## Adding lightweight context (still no external crates)

If you sometimes want more context (e.g., *which* file failed to parse), add a single contextual variant and a tiny helper trait:

```rust
// in error.rs
use std::borrow::Cow;

#[derive(Debug)]
pub enum Error {
    Io(std::io::Error),
    Json(serde_json::Error),
    Context { msg: Cow<'static, str>, source: Box<dyn StdError + Send + Sync> },
    Other(String),
}

impl From<std::io::Error> for Error { /* as before */ }
impl From<serde_json::Error> for Error { /* as before */ }

// Extension trait to tack on context ergonomically
pub trait ResultExt<T> {
    fn context(self, msg: impl Into<Cow<'static, str>>) -> Result<T>;
}

impl<T, E> ResultExt<T> for std::result::Result<T, E>
where
    E: StdError + Send + Sync + 'static,
{
    fn context(self, msg: impl Into<Cow<'static, str>>) -> Result<T> {
        self.map_err(|e| Error::Context { msg: msg.into(), source: Box::new(e) })
    }
}
```

Usage:

```rust
use crate::error::ResultExt;

fn load_config(path: &str) -> Result<serde_json::Value> {
    let data = std::fs::read_to_string(path)
        .context(format!("reading config from {path}"))?;
    serde_json::from_str(&data)
        .context(format!("parsing JSON from {path}"))
}
```

- This keeps your enum **typed** for the common cases and uses one **boxed** variant **only when you add context**, not “all over the place.”
- If you want stack traces during debugging, you can attach a `std::backtrace::Backtrace` field to `Context` and enable it with `RUST_BACKTRACE=1` (or `RUST_LIB_BACKTRACE=1`). [5](https://doc.rust-lang.org/std/backtrace/index.html)
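
As a concrete illustration of that last bullet, here is a minimal sketch (my own addition, not taken from the sources cited above) of a `Context` variant that also captures a backtrace; the constructor name is illustrative:

```rust
use std::backtrace::Backtrace;
use std::borrow::Cow;
use std::error::Error as StdError;

#[derive(Debug)]
pub enum Error {
    Context {
        msg: Cow<'static, str>,
        source: Box<dyn StdError + Send + Sync>,
        backtrace: Backtrace, // populated only when the backtrace env vars are enabled
    },
    // ... other variants as before
}

impl Error {
    // Hypothetical helper; the real constructor/trait can look different.
    fn with_context(msg: impl Into<Cow<'static, str>>, source: impl StdError + Send + Sync + 'static) -> Self {
        Error::Context {
            msg: msg.into(),
            source: Box::new(source),
            backtrace: Backtrace::capture(), // respects RUST_BACKTRACE / RUST_LIB_BACKTRACE
        }
    }
}
```

`Backtrace::capture()` stays cheap when the env vars are not set, so a variant like this can be left in permanently.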
---

## Answers to your specific questions

### 1) “Should I use `?` as much as possible?”

**Yes.** Use `?` whenever you cannot or do not want to **recover** at that level; it's the idiomatic, concise way to propagate errors, and it triggers `From` to convert error types along the way. [1](https://doc.rust-lang.org/reference/expressions/operator-expr.html) Reserve explicit `match` only for cases where you're **handling** the error (retrying, defaulting, logging with additional side-effects, etc.). The Book's error chapter lays out this philosophy (recoverable `Result` vs unrecoverable `panic!`). [4](https://doc.rust-lang.org/book/ch09-00-error-handling.html)

### 2) “What should my typical return type be?”

- **In library crates:** `pub type Result<T> = std::result::Result<T, Error>;` with a **crate-specific `Error` enum**. This keeps signatures short and type-checked. [2](https://doc.rust-lang.org/std/error/trait.Error.html)
- **In binary crates:** either reuse your library `Error` or keep the boundary generic: `fn main() -> Result<(), Box<dyn std::error::Error>>` (simplest), relying on `Termination` to convert `Err` to a non-zero exit code. [3](https://doc.rust-lang.org/std/process/trait.Termination.html)

### 3) “How do I capture external crate errors upstream without hassle?”

Add **`From` impls** to your error enum; then `?` “just works.” Also, implement `source()` to expose the underlying error so higher layers (or your CLI) can print the chain. The standard library's `Error` trait is designed for this chaining. [2](https://doc.rust-lang.org/std/error/trait.Error.html)

```rust
impl From<reqwest::Error> for Error {
    fn from(e: reqwest::Error) -> Self { Error::Other(e.to_string()) }
}
// or better: add a dedicated variant `Http(reqwest::Error)` and a Display arm
```

When adding **context**, use the small `ResultExt::context` trick above rather than writing verbose `map_err` everywhere.

### 4) “Should I define my own error types to make this easier?”

**Yes.** A **domain-specific error enum per crate** is a widely used pattern. It avoids stringly-typed errors, makes it easy to distinguish cases, and composes well. The Book/`std::error::Error` docs show how to implement `Display` + `Error` so they print nicely and chain correctly. [2](https://doc.rust-lang.org/std/error/trait.Error.html)[4](https://doc.rust-lang.org/book/ch09-00-error-handling.html) If you expect your library crates to evolve, you can mark the enum `#[non_exhaustive]` (optional, more relevant for public APIs).

### 5) “Is there a way to avoid using `Box<dyn Error>` all over the place?”

**Absolutely.** Keep errors **typed** (enums) throughout your libs; use at most **one** `Box<dyn Error>`:

- at the binary boundary (`main`) for convenience, and/or
- in a single **contextual** variant (if you want arbitrary sources + messages).

Everywhere else, use **concrete variants**, which are zero-cost and ergonomic with `?`.

## A few “expert-style” tips

- **When to `panic!`**: only for programmer bugs / invariant violations (e.g., `unreachable!()`, “this should never happen”), not for I/O or user-driven failures. The Book explicitly distinguishes recoverable from unrecoverable errors. [4](https://doc.rust-lang.org/book/ch09-00-error-handling.html)
- **Display vs Debug**: make `Display` short and human-readable; rely on `source()` chains for details. `Debug` can be derived. [2](https://doc.rust-lang.org/std/error/trait.Error.html)
- **Backtraces** in development: consider adding a `Backtrace` field (captured on construction) to your contextual variant and enable it via env vars when you need it. [5](https://doc.rust-lang.org/std/backtrace/index.html)
- **`main` ergonomics**: returning `Result` from `main` is fine and standard; the `Termination` trait glues it to exit codes. If you need full control over formatting, handle the top error yourself instead of returning it. [3](https://doc.rust-lang.org/std/process/trait.Termination.html)

## If you ever relax “no external crates”

Many experienced Rustaceans use:

- **`thiserror`** (derive `Error` + `Display` + `#[from]` on variants) to **eliminate boilerplate** in library error enums. [6](https://docs.rs/thiserror/latest/thiserror/)
- **`anyhow`** in **binary crates** for an ergonomic `anyhow::Error` + `.context()` builder (you can still keep typed errors in libs). [7](https://docs.rs/anyhow/latest/anyhow/)

You don't *need* them for your scenario, but they are the shortest path to zero boilerplate if you change your mind.

## TL;DR (what to do tomorrow)

1) **Per lib**: create `Error` enum + `Result` alias; implement a handful of `From` impls and `Display`.
2) **Use `?`** everywhere you propagate; `map_err` only when you genuinely handle/transform; optional `ResultExt::context` to add messages. [1](https://doc.rust-lang.org/reference/expressions/operator-expr.html)
3) **Per bin**: `fn main() -> Result<(), Box<dyn std::error::Error>>` (or your top-level typed error) and let `Termination` handle exit codes. [3](https://doc.rust-lang.org/std/process/trait.Termination.html)
4) Keep `Box<dyn Error>` at the edges, **not** in your core APIs.

If you want, I can **generate the error scaffolding** for each of your crates (enums + `From` impls + `ResultExt`) based on the external errors you use—just tell me which crates pop up most (e.g., `tokio`, `reqwest`, `serde_json`, `rusqlite`, etc.). Also, are you OK with a single derive crate (`thiserror`) if it replaces ~50–100 lines of boilerplate per crate, or do you want to stay 100% std-only? [6](https://docs.rs/thiserror/latest/thiserror/)
```

### `book/src/contributing.md`

**Absolute path:** `/root/mrpf/book/src/contributing.md`

```markdown
# Contributing

At the moment the project is solely maintained by me, with the purpose of learning and experimenting with Rust, distributed systems and bug bounties. I haven't got any intentions of putting this out into the world as I'd like to be able to break things when I want and work at my own pace.

However, I love to talk to like-minded people. If you happened to come across my hidden little corner of the internet, feel free to get in [touch](https://x.com/thiezn_).
```

### `book/src/core_principles.md`

**Absolute path:** `/root/mrpf/book/src/core_principles.md`

```markdown
# Core Principles

Computer systems, protocols and applications are all built on abstractions. These abstractions help us reason about higher level concepts and speed up development by hiding complexity and avoiding reinventing the wheel. However, in security research it's crucial to understand the underlying reality behind these abstractions. Security vulnerabilities often arise from a mismatch between these layers of abstraction.

I love to understand how things really work. When you start to dig into the lower level systems, you develop a good intuition for how a system operates.
I've always had a tendency to want to build everything from scratch, as if it's cheating to use something that you don't fully understand. This can hold you back, but now with the advent of AI, learning and building systems has become much easier.

My goal with MRPF is to try to keep the underlying reality front and center. This sometimes comes at the cost of less intuitive or more verbose interfaces, but it differentiates the toolset from most other tools out there.

TODO: Write my blog about misconceptions around hosts, ips, domains, root paths, dns, TLS SNI etc.
```

### `book/src/mpf_previous_iteration/README.md`

**Absolute path:** `/root/mrpf/book/src/mpf_previous_iteration/README.md`

```markdown
# MPF - The previous iteration

Before there was MRPF, there was MPF. The original idea was the same: build my own tooling around bug bounty hunting. It was built in Python and I've learned a lot from building it.

I wanted to investigate if I could build a custom network stack and have more control over HTTP traffic, but Python was a bit more limited here. For instance, a lot of the HTTP client libraries don't allow for much customization around TLS. The libraries often build abstractions around the networking layers, making it difficult to customize things like TLS SNI or ALPN. Also, concepts like domain names, host names and ip addresses are all mixed together. This is nice from a user perspective, but I want very specific control over all these factors to find misconfigurations.

Initially I started to build a custom network stack in C using libuv. Unfortunately my laptop crashed and I was stupid enough not to commit all that code into a repo. Also, I was very much struggling with writing safe concurrent C code.

After a while I decided to try and rebuild the network stack and looked into Rust. By this time ChatGPT was really getting good and it really helped me quickly get up to speed with a new language. I learned a LOT and started to love some of the Rust concepts like ownership, fearless concurrency and the way it makes refactoring code bases a lot easier.

This has now led me to try and rebuild MPF completely in Rust. Hence the M(y) (Rusty) P (Pension) F (Fund) project.
```

### `book/src/mpf_previous_iteration/what_works_and_what_not.md`

**Absolute path:** `/root/mrpf/book/src/mpf_previous_iteration/what_works_and_what_not.md`

```markdown
# What are the things that don't work that well in my current iteration of MPF

- The task manager parallelization is not optimal
- The memory management of building parallel tasks is not optimal
- Python is quite memory hungry for large scans
- All scanning tasks work on Lambda, I can't mix in bare metal or containers
- Job scheduling is helpful but also very repetitive for each target. It would be better to have generic continuous scanning for all targets. This will help spread out load as well with my new randomizing Rust scanners
- The database model has some limitations:
  - Nothing showing where certain results came from
  - Not possible to construct the mermaid graph representation I came up with
  - Too many things stacked inside the Domain object that are not 100% correct. For example, IP addresses should be their own entities, tcp/udp ports are related to an ip not a domain, and the order of resolved IP addresses in the domain object is not static, making it seem like we have a lot of updates.
- The task manager code is quite difficult to read and I'm not confident it's robust enough.
- Introducing new tasks is quite labor intensive - Tasks do not easily show the task template they belong to, making parsing log files more difficult - The statistics of all the scans happening are not easily accessible or useful - THE MAIN THING, it hasn't helped me a single time to find or get closer to any actual bug/bounty. I have learned a bunch of things though so that is something.. ## Lets watch the Beazley talk and build my job along side it ``` ### `book/src/network_engine/README.md` **Absolute path:** `/root/mrpf/book/src/network_engine/README.md` ```markdown # Network Engine The core of my network stack is based on masscan. Here's a diagram of the three threads used: ## Sending Thread Construct Ethernet Packet: Create the Ethernet frame with appropriate source and destination MAC addresses. Construct IP Packet: Create the IP packet with source and destination IP addresses. Generate Sequence Number: Generate a unique sequence number based on the source/destination IP and port pairs. Create TCP Packet: Construct the TCP packet with the generated sequence number and other necessary fields. Send Packet: Send the constructed packet over the network. Send Status: Notify the status report thread that a packet has been sent. ## Receiving Thread Listen for Incoming Packets: Continuously listen for incoming packets on the network. Filter Relevant Packets: Filter out packets that are not relevant based on the unique sequence number. Handle Packet: Process the relevant packet (e.g., extract data, acknowledge receipt). Send Status: Notify the status report thread that a packet has been received and handled. ## Status Report Thread Receive Status Updates: Continuously receive status updates from the sending and receiving threads. Update Status and Statistics: Update the current status and statistics of the scan based on the received updates. Print Status and Statistics: Print the updated status and statistics to the console or log. This diagram and description should help visualize the flow and interaction between the threads in your scanning application. If you have any further questions or need additional details, feel free to ask! ``` mermaid graph TD A[Main Thread] -->|Start| B[Sending Thread] A -->|Start| C[Receiving Thread] A -->|Start| D[Status Report Thread] B --> B1[Construct Ethernet Packet] B1 --> B2[Construct IP Packet] B2 --> B3[Generate Sequence Number] B3 --> B4[Create TCP Packet] B4 --> B5[Send Packet] B5 -->|Send Status| D C --> C1[Listen for Incoming Packets] C1 --> C2[Filter Relevant Packets] C2 --> C3[Handle Packet] C3 -->|Send Status| D D --> D1[Receive Status Updates] D1 --> D2[Update Status and Statistics] D2 --> D3[Print Status and Statistics] ``` ## Rate limiting The `transmit_handler` function implements a rate limit bucket algorithm to control the rate at which packets are sent. The rate limit bucket algorithm is a mechanism to control the rate of packet transmission. It works as follows: 1. **Initialization** - A token bucket is initialized with a certain number of tokens (`RATE_LIMIT_PACKETS_PER_INTERVAL`). - Each token represents permission to send one packet. - The bucket is refilled at regular intervals (`RATE_LIMIT_INTERVAL`). 2. **Packet Transmission** - For each packet to be sent, the algorithm checks if there are tokens available in the bucket. - If tokens are available, a token is consumed, and the packet is sent. - If no tokens are available, the algorithm waits until the bucket is refilled. 3. 
**Refilling the Bucket** - The bucket is refilled at a fixed interval (`RATE_LIMIT_INTERVAL`). - When the interval elapses, the bucket is refilled to its maximum capacity (`RATE_LIMIT_PACKETS_PER_INTERVAL`). 4. **Handling Buffer Full Errors** - If the packet transmission fails due to a full buffer (`NO_BUFFER_SPACE_AVAILABLE_ERROR`), the algorithm waits for a short period (`100ms`) before retrying. This algorithm ensures that packets are sent at a controlled rate, preventing network congestion and ensuring fair usage of network resources. ## Great reads [This blog post](https://corrode.dev/blog/async/) nicely describes the pro's and con's of async Rust vs normal threads. It nicely illustrates that async Rust is not always the best choice for all use cases. Obviously we would need to include [how masscan works](https://rushter.com/blog/how-masscan-works/) And the bulk of the masscan code can be found in [main.c](https://github.com/robertdavidgraham/masscan/blob/master/src/main.c) ``` ### `book/src/network_engine/dns_resolver.md` **Absolute path:** `/root/mrpf/book/src/network_engine/dns_resolver.md` ```markdown # DNS Resolver ## Ideas and Future Work ... ``` ### `book/src/network_engine/http11_scanner.md` **Absolute path:** `/root/mrpf/book/src/network_engine/http11_scanner.md` ```markdown # HTTP/1.1 Scanner ## Ideas and Future Work ... ``` ### `book/src/network_engine/ideas/MY_FIRST_ORIGINAL_IDEA_FOR_BETTER_THROUGHPUT.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/MY_FIRST_ORIGINAL_IDEA_FOR_BETTER_THROUGHPUT.md` ```markdown # Idea This came to me during an outdoor run, the first one in a month or two letting myself go. A good reminder that working out really benefits you and makes life better overall. LATER UPDATE, I did some work on this: Check the following branch: My idea: I'm building a verty fast and memory efficient tcp syn scanner using rust pnet library on the datalink layer. I've split transmit and receive of packets into separate threads. The transmit thread at the moment composes a tcp packet from the ethernet layer up, and then uses the pnet datalink tx channel to send traffic out. The send function is taking up a large percentage of the total memory used. What if i would create a larger buffer where i combine multiple packets, and then finally send that full buffer using the datalink tx channel in one go. Does the network layer below that be able to send those packets in one go and would it perhaps reduce memory footprint and less switching of user and kernel space? Clearly think about your response and outline the pro's con's and potential pitfalls ## Copilot's response It sounds like you're working on an interesting and challenging project! Let's break down your idea and consider the pros, cons, and potential pitfalls of using a larger buffer to combine multiple packets before sending them. ### Pros 1. **Reduced Memory Footprint**: By combining multiple packets into a single buffer, you might reduce the overhead associated with multiple small allocations, potentially lowering the overall memory usage. 2. **Fewer System Calls**: Sending a larger buffer in one go could reduce the number of system calls, which might decrease the overhead of context switching between user and kernel space. 3. **Improved Throughput**: Aggregating packets could lead to better utilization of the network interface, potentially improving throughput. ### Cons 1. 
**Increased Latency**: Waiting to fill a larger buffer before sending could introduce latency, especially if the buffer size is too large or if packet generation is not fast enough. 2. **Complexity in Buffer Management**: Managing a larger buffer and ensuring that packets are correctly combined and sent can add complexity to your code. This includes handling partial sends and ensuring data integrity. 3. **Potential for Packet Loss**: If the buffer is not managed properly, there is a risk of packet loss, especially if the buffer overflows or if there are errors in combining packets. ### Potential Pitfalls 1. **Network Layer Limitations**: The underlying network layer might have limitations on the maximum transmission unit (MTU). If your combined buffer exceeds this size, it could lead to fragmentation, which might negate some of the performance benefits. 2. **Error Handling**: With a larger buffer, error handling becomes more critical. If an error occurs during the transmission of a large buffer, it could affect multiple packets, leading to more significant data loss. 3. **Synchronization Issues**: If multiple threads are involved in filling and sending the buffer, you need to ensure proper synchronization to avoid race conditions and data corruption. ### Recommendations - **Buffer Size Tuning**: Experiment with different buffer sizes to find an optimal balance between reducing memory usage and minimizing latency. - **Error Handling**: Implement robust error handling to manage partial sends and retransmissions effectively. - **Profiling and Testing**: Profile your application to understand the impact of the changes on memory usage and performance. Conduct thorough testing to ensure that the new approach does not introduce new issues. By carefully considering these factors, you can make an informed decision on whether to implement a larger buffer for sending packets. Good luck with your TCP SYN scanner project! If you have any more questions or need further assistance, feel free to ask. ## Github Copilot response in vscode to the same question GitHub Copilot: Combining multiple packets into a larger buffer and sending them in one go can indeed reduce the overhead associated with frequent system calls and context switches between user and kernel space. This approach can lead to better performance and reduced memory footprint. However, there are several pros, cons, and potential pitfalls to consider: ### Pros 1. **Reduced System Call Overhead**: By batching multiple packets into a single buffer, you reduce the number of system calls, which can be expensive in terms of CPU cycles and context switching. 2. **Improved Throughput**: Sending larger buffers can improve throughput as the network interface can handle larger chunks of data more efficiently. 3. **Lower Memory Fragmentation**: Allocating and deallocating memory for individual packets can lead to fragmentation. Using a larger buffer can help mitigate this issue. ### Cons 1. **Increased Latency** Batching packets may introduce latency as packets are held in the buffer until it is full or a timeout occurs. 2. **Complexity**: Implementing a buffering mechanism adds complexity to the code. You need to manage the buffer, handle partial writes, and ensure thread safety. 3. **Buffer Management**: You need to carefully manage the buffer size to avoid excessive memory usage or buffer overflows. ### Potential Pitfalls 1. **Buffer Overflow**: If the buffer size is not managed properly, it can overflow, leading to data loss or corruption. 2. 
**Thread Safety**: If multiple threads are accessing the buffer, you need to ensure thread safety using synchronization mechanisms, which can introduce overhead.
3. **Network Interface Limitations**: Some network interfaces may have limitations on the maximum buffer size they can handle. You need to ensure that the buffer size is within these limits.
4. **Error Handling**: Handling errors in a batched send operation can be more complex. You need to ensure that partial writes are handled correctly and that the buffer is properly managed in case of errors.

### Implementation Outline

1. **Buffer Initialization**: Initialize a buffer to hold multiple packets.
2. **Packet Composition**: Compose packets and add them to the buffer.
3. **Buffer Flush**: When the buffer is full or a timeout occurs, send the buffer using the datalink tx channel.
4. **Thread Safety**: Ensure that the buffer is accessed in a thread-safe manner.

Here's a basic implementation outline:

```rust
use pnet::datalink::DataLinkSender;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

const BUFFER_SIZE: usize = 4096; // Adjust as needed
const FLUSH_INTERVAL: Duration = Duration::from_millis(50);

struct PacketBuffer {
    buffer: Vec<u8>,
    last_flush: Instant,
}

impl PacketBuffer {
    fn new() -> Self {
        PacketBuffer {
            buffer: Vec::with_capacity(BUFFER_SIZE),
            last_flush: Instant::now(),
        }
    }

    fn add_packet(&mut self, packet: &[u8]) {
        self.buffer.extend_from_slice(packet);
    }

    fn should_flush(&self) -> bool {
        self.buffer.len() >= BUFFER_SIZE || self.last_flush.elapsed() >= FLUSH_INTERVAL
    }

    fn flush(&mut self, tx: &mut dyn DataLinkSender) {
        if !self.buffer.is_empty() {
            tx.send_to(&self.buffer, None).unwrap();
            self.buffer.clear();
            self.last_flush = Instant::now();
        }
    }
}

pub(crate) fn transmit_handler(
    mut tx: Box<dyn DataLinkSender>,
    packet_rx: Receiver<Vec<u8>>,  // channel generics were lost in this copy; Vec<u8> and Progress are assumed
    progress_tx: Sender<Progress>,
    mut config: ScannerConfig,
    router_mac: MacAddr,
    ipv4_addresses: Ipv4Addresses,
    ports: Ports,
) {
    let packet_buffer = Arc::new(Mutex::new(PacketBuffer::new()));

    // Start rate limiter
    config.rate_limiter.start();

    for dst_port in ports {
        let src_port: u16 = rand::thread_rng().gen_range(49152..=60999);
        for dst_ip in ipv4_addresses.clone() {
            let packet = create_syn_tcp_packet_v4(src_port, dst_port, dst_ip, router_mac, &config);
            {
                let mut buffer = packet_buffer.lock().unwrap();
                buffer.add_packet(&packet);
                if buffer.should_flush() {
                    buffer.flush(&mut *tx);
                }
            }
        }
    }

    // Ensure any remaining packets are sent
    {
        let mut buffer = packet_buffer.lock().unwrap();
        buffer.flush(&mut *tx);
    }
}
```

### Explanation

1. **PacketBuffer Struct**: Manages the buffer and handles adding packets and flushing the buffer.
2. **Buffer Initialization**: The buffer is initialized with a capacity of `BUFFER_SIZE`.
3. **Packet Composition**: Packets are composed and added to the buffer.
4. **Buffer Flush**: The buffer is flushed when it is full or when the flush interval has elapsed.
5. **Thread Safety**: The buffer is wrapped in an `Arc<Mutex<PacketBuffer>>` to ensure thread safety.

This implementation reduces the number of system calls and context switches, potentially improving performance and reducing memory usage. However, you need to carefully manage the buffer size and ensure thread safety to avoid potential pitfalls.

## After initial implementation

It seems pnet uses the libc::sendto call to send the packets. This call expects you to send a single packet and the kernel wraps it into a single frame. This means all the appended data (aka packets after the first packet) is sent but dropped completely by the receiving end.
According to Copilot, you could instead use the sendmsg call to accomplish this. This does mean we won't be able to leverage pnet for this. Perhaps we need to bite the bullet and write our own datalink layer implementation. A bunch of work obviously, but it would be a good learning experience.

The answer in [this stack overflow](https://stackoverflow.com/questions/4258834/how-sendmsg-works) post does indeed confirm my suspicion that constructing this yourself using sendmsg *might* improve performance when doing this in user space. It states that sendto is actually just a wrapper around sendmsg.

IF we're able to get the socket fd from pnet, I should be able to just add this sendmsg function myself. I'll have to look into this. Maybe even a nice PR to pnet if it works out.

OK, looking at the code, on Linux you can provide a socket to the config object. This means we can generate our socket ourselves, which will be something like this:

```rust
// Note: the protocol argument is the 16-bit EtherType in network byte order
let socket = match unsafe {
    libc::socket(
        libc::AF_PACKET,
        libc::SOCK_RAW,
        (libc::ETH_P_ALL as u16).to_be() as libc::c_int,
    )
} {
    -1 => return Err(io::Error::last_os_error()),
    fd => fd,
};
```

Here's Copilot's response to the sendmsg implementation:

Yes, you can use the `sendmsg` system call to send multiple packets in a single call. The `sendmsg` function allows you to use a scatter/gather array, which can be used to send multiple buffers (or packets) in one go. This is done using the `msghdr` structure, which includes an array of `iovec` structures, each pointing to a different buffer.

### Example of `sendmsg` Usage

Here's a simplified example in C to illustrate how you might use `sendmsg` to send multiple packets:

```c
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

int main() {
    int sockfd;
    struct sockaddr_in dest_addr;
    struct msghdr msg;
    struct iovec iov[2];
    char packet1[] = "Packet 1 data";
    char packet2[] = "Packet 2 data";

    // Create socket
    sockfd = socket(AF_INET, SOCK_DGRAM, 0);

    // Destination address
    memset(&dest_addr, 0, sizeof(dest_addr));
    dest_addr.sin_family = AF_INET;
    dest_addr.sin_port = htons(12345);
    dest_addr.sin_addr.s_addr = inet_addr("192.168.1.1");

    // Prepare iovec structures
    iov[0].iov_base = packet1;
    iov[0].iov_len = strlen(packet1);
    iov[1].iov_base = packet2;
    iov[1].iov_len = strlen(packet2);

    // Prepare msghdr structure
    memset(&msg, 0, sizeof(msg));
    msg.msg_name = &dest_addr;
    msg.msg_namelen = sizeof(dest_addr);
    msg.msg_iov = iov;
    msg.msg_iovlen = 2;

    // Send message
    if (sendmsg(sockfd, &msg, 0) == -1) {
        perror("sendmsg");
        return 1;
    }

    close(sockfd);
    return 0;
}
```

### Explanation

- **`iovec` Structures**: These structures point to the individual packets you want to send.
- **`msghdr` Structure**: This structure contains the destination address and the array of `iovec` structures.
- **`sendmsg` Call**: This sends all the packets in the `iovec` array in a single system call.

### Benefits

- **Efficiency**: Reduces the number of system calls, potentially improving performance.
- **Flexibility**: Allows you to send multiple packets or buffers in one go.

### Considerations

- **MTU Size**: Ensure that the total size of the combined packets does not exceed the MTU size to avoid fragmentation.
- **Error Handling**: Implement robust error handling to manage partial sends and retransmissions.

Using `sendmsg` can help you achieve your goal of sending multiple packets efficiently. If you have more questions or need further details, feel free to ask!
``` ### `book/src/network_engine/ideas/README.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/README.md` ```markdown # Ideas for Improvement This is a random collection of ideas I have for improving the network engine. Collected from random notes and thoughts I had lying around everywhere, trying to bring more structure to my notes. ``` ### `book/src/network_engine/ideas/TLS_parse_improvements.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/TLS_parse_improvements.md` ```markdown # Improvements to my certificate parse methods ## Skip any root or CA certificates When a server certificate contains more certs, ensure to only parse the FIRST one. The Certificate response has a: Certificates Length that contains total lenght. What follows is a list of - Certificate Length - Certificate Data Quote from the RFC: ```quote certificate_list This is a sequence (chain) of certificates. The sender's certificate MUST come first in the list. Each following certificate MUST directly certify the one preceding it. Because certificate validation requires that root keys be distributed independently, the self-signed certificate that specifies the root certificate authority MAY be omitted from the chain, under the assumption that the remote end must already possess it in order to validate it in any case. ``` ``` ### `book/src/network_engine/ideas/improvements_to_investigate_on_linux.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/improvements_to_investigate_on_linux.md` ```markdown Use SO_BUSY_POLL or PACKET_MMAP (TPACKET v3) to reduce syscall overhead. Pin RX/TX threads to dedicated cores (isolcpus, taskset). Disable CPU frequency scaling and C-states for consistent latency. Consider MSG_ZEROCOPY if you send large packets. kernel 6.12 has EEVDF which may help with networking latency. cat /sys/kernel/debug/sched/features | grep EEVDF ## Pin CPU's to my send and receive threads - Improve latency by increasing the priority: ```bash # Lower nice value = higher priority sudo renice -n -10 -p ``` - Add to your kernel boot params: ```bash isolcpus=nohz,domain,managed_irq, ``` - Figure out how to pin a thread to a core in Rust. - Disable frequency scaling? ```bash cpupower frequency-set -g performance ``` ``` ### `book/src/network_engine/ideas/recucing_pselect_calls.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/recucing_pselect_calls.md` ```markdown # Reducing pselect calls Looking at the flame graph, the receive thread spends quite some time in the pselect system call (about half). Obviously since it uses a timeout, it could be there are just no packets to receive. However, looking at the implementation, it seems the [pnet code](https://github.com/libpnet/libpnet/blob/a01aa493e2ecead4c45e7322b6c5f7ab29e8a985/pnet_datalink/src/bpf.rs#L300) Is only reading a single packet when pselect is ready. Instead, we should try to read more packets to avoid extra pselect calls when more packets are ready. UPDATE: No actually it seems the code IS reading multiple packets. It's dependent on the read buffer size of the datalink, so we could perhaps tweak performance a bit by playing around with the buffer size there. ``` ### `book/src/network_engine/ideas/refactor_network_stack.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/refactor_network_stack.md` ```markdown # Refactor network stack Remove dependency on libpnet and re-implement linux and macOS code myself. 
This will prepare me to choose the most optimal transmission method for the two platforms and for instance use XDP/eBPF to reduce kernel signals. I suspect the following library has a lot of the code that I could reuse for more flexibility to dig into the way we're sending packets: This will probably also avoid me from receiving packets not destined for me as we'll be having a dedicated interface. I have to look at [this part](https://github.com/pkts-rs/tappers/blob/master/src/macos/utun.rs), it uses writev to send a packet. It states that only ip and ipv6 are supported so I think it means there's no ethernet support. At the moment, the only reason for having ethernet support is to be able to spoof my ip address. Since we're going to have a dedicated virtual interface, we can just set the ip address of the interface to the one we want to spoof and we should be good to go. Had a quick look at the race condition attack and that uses ip and tcp, so we should be able to build that already. It's also really nice as it already splits the interface into tun(send) and tap(receive) devices. This will make it easier to split it between threads. Before doing this, I should carefully read [this blog post](https://blog.cloudflare.com/virtual-networking-101-understanding-tap/), as it might be that the tun/tap device actually has MORE syscalls than the current implementation. If a TUN/TAP device is not the way to go, look at libpnet and re-implement the relevant parts myself. ``` ### `book/src/network_engine/ideas/tls_parse_improvements.md` **Absolute path:** `/root/mrpf/book/src/network_engine/ideas/tls_parse_improvements.md` ```markdown # TLS parse improvements ``` ### `book/src/network_engine/infrastructure.md` **Absolute path:** `/root/mrpf/book/src/network_engine/infrastructure.md` ```markdown # Infrastructure Since we require privilege for using raw sockets, we no longer can run on AWS Lambda. However, fargate with spot pricing could be a good alternative. We have to think about how to deploy this. AWS Batch could help, or we can create an sqs queue that will hold our tasks, then after pushing our tasks we'll start up x amount of fargate tasks to retrieve stuff from sqs and kill them after a single job. ## AWS Lambda ARM - 0.0000000017 per 128mb per ms - 0.0000017 per 128mb per second - 0.000102 per 128mb per minute - 0.0000000267 per 2048mb per ms - 0.0000267 per 2048mb per second - 0.0016 per 2048mb per minute - pay per second - size measured in memory ## Fargate spot pricing ARM - 0,00016429 per vCPU per minute - pay per minute - size measured in vCPU, minimum = 2Gb mem This means lambda is 0,00016429 - 0.000102 = 0.00006229 CHEAPER per vCPU per minute, when we don't care about memory. This means lambda is 0,0016 - 0,00016429 = 0.00143571 MORE EXPENSIVE per vCPU per second, when we compare to 2048mb lambda. Lets say we will run 10 tasks for a full hour a day for every day of the month. With lambda 2048mb we would pay 0.0016 *60* 10 * 31 = $29.76 With fargate we would pay 0,00016429 * 60 * 10 * 31 = $3.055794 UPDATE, Fargate cost described above is actually still missing the GB per hour. BUT I also see fargate tasks have a lower minimum, the minimum is 0.25vCPU with 0.5GB. They have a fargate pricing calucation example (I don't think it uses spot instances even!): 5 tasks running on ARM for 10 minutes every day, with 1vCPU and 2GB mem, for the whole month cost a total of $1.02. That is missing data transfer cost and public ip cost but still, I think we can work with that!! 
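
For reference, the 10-tasks-for-an-hour-a-day comparison above written out as a quick calculation (my own restatement of the numbers quoted above; Fargate memory, data transfer and public IP costs are still excluded):

```rust
fn main() {
    // Per-minute prices quoted above (USD)
    let lambda_2048mb_per_minute = 0.0016;
    let fargate_spot_arm_per_vcpu_minute = 0.00016429;

    // 10 tasks, 60 minutes per day, 31 days
    let minutes = 10.0 * 60.0 * 31.0;

    // ~$29.76 for Lambda at 2048 MB, ~$3.06 for Fargate spot (vCPU component only)
    println!("Lambda (2048 MB): ${:.2}", lambda_2048mb_per_minute * minutes);
    println!("Fargate spot (vCPU only): ${:.2}", fargate_spot_arm_per_vcpu_minute * minutes);
}
```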
AWS Batch will be great for this as well. ``` ### `book/src/network_engine/os_network_stack_challenges.md` **Absolute path:** `/root/mrpf/book/src/network_engine/os_network_stack_challenges.md` ```markdown # OS Kernel stack interferes with custom network engine Summary: - Either we spoof our IP. This is preferred where possible as it will minimize us having to parse traffic not meant for the scanner. However, on AWS EC2 this will not be possible as the Internet Gateway will do NAT translation for the source private ip to the public ip of the instance. So replies will not come back to us if we spoof a non-local ip. - Use firewall rules (iptables or nftables) to drop RST packets going out from the OS kernel for the source ports we are using in our custom network engine. This way the kernel will not interfere with our crafted packets. Apparently this is how Masscan does it. This does also mean we can't send RST packets either? Actually the masscan page says linux maintains a llist of ports they use /proc/sys/net/ipv4/ip_local_port_range so we can pick ports outside of that range to avoid interference from the OS kernel. ## ChatGPT question and answer I've written a high performance network engine where the sending of packets and receiving of them happen in separate threads. Using syn_cookies we're tracking connections in our own table. We use raw sockets to send and receive packets. It seems that the OS itself is interfering with the packet streams and sending RST packets to connections i'm trying to establish. The OS we're running this on is Linux (EC2 ARM64 instances). Is there a way to stop the linux network stack from interfering with my traffic? GitHub Copilot: You can’t “turn off” the kernel TCP implementation, but you can prevent it from injecting RSTs (and other corrective packets) for traffic you craft with raw/AF_PACKET sockets. Main approaches: 1. Firewall drop outbound RST (simplest, masscan-style) sudo iptables -t raw -I OUTPUT -p tcp --tcp-flags RST RST -j DROP (Use -t raw so it’s early, avoids conntrack overhead.) Limit to your scan source port range to avoid breaking normal connections: sudo iptables -t raw -I OUTPUT -p tcp --sport 40000:50000 --tcp-flags RST RST -j DROP Remove afterward: sudo iptables -t raw -D OUTPUT -p tcp --sport 40000:50000 --tcp-flags RST RST -j DROP 2. nftables equivalent (preferred on newer systems) sudo nft add table inet raw sudo nft add chain inet raw output '{ type filter hook output priority -300; }' sudo nft add rule inet raw output tcp sport 40000-50000 tcp flags & rst == rst drop Cleanup: sudo nft delete table inet raw 3. Separate network namespace Put the scanner in its own netns with its own IP; apply RST-drop rules only there: sudo ip netns add scan sudo ip link set dev eth0 netns scan (Configure IP/route inside the namespace, then apply nft/iptables rule.) This isolates side effects from the host’s normal traffic. 4. eBPF/XDP drop Attach a minimal XDP program on the interface to drop RST frames generated locally (or SYN-ACK responses you don’t want the kernel to process). Higher performance, but more engineering. Only worth it if firewall rule becomes a bottleneck at very high PPS. 5. Spoof non-local source IPs Not usually practical unless you control routing and disable reverse path filtering (sysctl net.ipv4.conf.all.rp_filter=0). Still need a way to receive replies; typically not recommended unless you’re doing wide Internet scans and have proper authorization. 
Notes: - Dropping all RST can leave remote hosts thinking connections are half-open longer; limit by port range/time. - Keep source ports in a confined range so filtering is easy and predictable. - Disable any automatic packet offloads that might interfere (ethtool -K eth0 tso off gso off gro off lro off) if you see unexpected segmentation behavior. - For UDP scanning, similarly block ICMP Port Unreachable if needed (iptables -I OUTPUT -p icmp --icmp-type port-unreachable -j DROP). Revert rules after scanning to restore normal TCP behavior. That’s the practical way: filter out the kernel’s corrective packets rather than trying to disable the stack. ``` ### `book/src/network_engine/tcp_syn_scanner.md` **Absolute path:** `/root/mrpf/book/src/network_engine/tcp_syn_scanner.md` ```markdown # Tcp Syn Scanner ## Ideas and Future Work ... ``` ### `book/src/network_engine/tls_scraper.md` **Absolute path:** `/root/mrpf/book/src/network_engine/tls_scraper.md` ```markdown # TLS Scraper ## Ideas and Future Work ... ``` ### `book/src/network_engine/whois_resolver.md` **Absolute path:** `/root/mrpf/book/src/network_engine/whois_resolver.md` ```markdown # Whois Resolver ## Ideas and Future Work ... ``` ### `book/src/security_ideas/README.md` **Absolute path:** `/root/mrpf/book/src/security_ideas/README.md` ```markdown # Bug Bounty Ideas This section contains random ideas I have regarding bug bounty hunting, security research and general security topics. ``` ### `book/src/security_ideas/ai_wordlists.md` **Absolute path:** `/root/mrpf/book/src/security_ideas/ai_wordlists.md` ```markdown # AI-Generated Wordlists - Build instructions for different attack vectors or reconnaissance techniques. - Provide a target (e.g. a domain, ip range or a MRPF Target object), existing wordlist and the build instruction to an AI model and let it generate a new wordlist. Store the generated words and run your scans. Gather statistics on the results and repeat the process. ## Good articles to get ideas from for instructions - finding actuators [here](https://www.dsecured.com/en/articles/spring-boot-actuator-using-misconfig-to-your-advantage-paths-bypasses-techniques), [here](https://www.wiz.io/blog/spring-boot-actuator-misconfigurations) and [here](https://blog.wss.sh/how-i-made-more-than-30k-with-jolokia-cves/) ``` ### `book/src/security_ideas/aws_protections.md` **Absolute path:** `/root/mrpf/book/src/security_ideas/aws_protections.md` ```markdown # Protections implemented by AWS for scanning Blackfoot is the EC2 elastic networking backend AWS MadPot: - When they detect a connection to a ip that’s not allocated to any Elastic IP they will know it’s a scanner so they can mark your ip. This is not great for my TLS scanner… - Sonaris is the S3 protection Important part will be if you can run your scanner on AWS infrastructure that uses dynamic outbound public ips, they very likely won’t block your IP as that would mean they could impact other AWS customers. I'm not sure where I got the above information anymore, I think it was in a podcast somewhere from Critical Thinking? ## Summary of some services Learned this during some discussions - Mithra - network to inspect DNS request per region. Is also give a 'benign reputation source',that guardduty uses to prevent false positives. Route53 domain blocking also uses Mithra, perhaps also some AWS internal services use it but wasn't really clear. - MadPot - Think a more standard honeypot solution. 
When it detect a proper validated attack, it can replicate the blacklisted ips to the whole network. - Blackfoot - analyze all inbound and outbound flows (13T flows an hour) to VPCs. How many of these come from malicious ips, and they use MadPot to determine if it's a really malicious ip. - Sonaris - internal threat intelligence tool looks at network traffic and find potential security threats. Finds attempts of people trying to find public buckets, vulnerable services on EC2, etc. - SnowTube, what public IPs are associded with EC2. is published to an SNS topic. Would be beautiful if we can subscribe to this topic?!! are there explicit accounts or Org conditions? Can we levarage AWS services to listen to this topic? How can we find out the name of the SNS topic? - IP Ownership. this is a service managed by EC2 team - EC2 which ip addresses are associated with what instances for a point in time ## How does GuardDuty work? - S3 malware uses bitdefender to help with hashes. They also have a few other internal rules for it. - GuardDuty gathers all required data itself, does not need you to enable it (vpc flow logs, cloudtrail, dns logs, S3/RDS access logs) ### GuardDuty infrastructure GuardDuty is built using a lot of the 'normal' AWS services, like Lambda, S3, EC2, RDS, Firehose. - Frontend running in customer account, these are the actual resources that will be checked. - Non guardduty internal components, S3, Route53 logs, flow logs, service logs from s3, eks audit logs, IP Ownership, (this is a service managed by EC2 team - EC2 which ip addresses are associated with what instances for a point in time), Mithra (DNS inspection) Their evaluation components: - Stateless processor: evaluate, this is related to the threat intel providers. eg. ip ownership, external vendor intel (croudstrike and proofpoint are definitely used), etc - Stateful processing: This is where machine learning models are applied, what kind of things can it detect - Malware engine: Another service: - Account service: Which accounts do you have enabled guardduty on, what is the delegate account, what features are enabled? etc? Security boundaries: GuardDuty runs internally across a whole lot of these above 'micro'services. They spread their services into different accounts, using it as a security boundary. Often they just use IAM roles and resource policies to control this, they don't put everything behind API gateways etc. ## DNS Graph statistics They get all their data from Route53 to build their mitigfations (200TB DNS logs with 5B subdomain nodes oktober-2025). Domain (TLD + 1) -> CNAME -> DNS Subdomain -> DNS -> EC2 Instances Subdomain -> DNS -> AWS Account Domain reputation pipeline: 1. Create a graph for the Domain target 2. Train models on the graph 3. Evaluate models using ?? ### Firenze for Model Evaluation What are the manual sec engineer steps eg. for evaluation domains. - New domain comes, are any ip addresses already sinkholed, so more likely to be malicious. Is it low popularity, is it nonsense, is the TLD abused often? Firenze will use the signals that sec engineers generate or engineers identify new of these weak signals, to better evaluate a model and provide guardrails. This is used to improve Mithra. 
There is a whitepaper firenze-model-evaluation-using-weak-signals ### High level getting findings into GuardDuty - ingest signals, apply ETL - Signals delta table -> Clustering - Clustering - Compoind signals Delta Table -> Pripritization - Compoind signals Delta Table -> training - Prioritization -> attach sequences (s3) - attach sequences (s3) -> Security Hub - attach sequences (s3)-> Finding Gateway -> Findings into the API for GuardDuty console. ``` ### `book/src/security_ideas/aws_web_identity.md` **Absolute path:** `/root/mrpf/book/src/security_ideas/aws_web_identity.md` ```markdown # AWS default web identity providers AWS offers default web identity providers for the AssumeRoleWithWebIdentity API. These providers are pre-configured to work with popular identity providers like Amazon Cognito, Google, and Facebook. Scan Javascript files for the known AWS SDK function calls that perform AssumeRoleWithWebIdentity. If you find one, try to assume it from google, facebook or cognito. ``` ### `book/src/security_ideas/random_ideas.md` **Absolute path:** `/root/mrpf/book/src/security_ideas/random_ideas.md` ```markdown # Future features Note Jeff Bezos last bit of [this youtube video](https://x.com/epictrades1/status/1864451316254155040), thinking small is a self fulfilling prophecy. LETS BUILD THE BEST BUG BOUNTY TOOLS IN THE WORLD! THIS INCLUDES ULTRA FAST AND SCALABLE SCANNERS, AND A BEAUTIFUL GUI TO MANAGE THEM! PROXY SERVER, REPEATER, BIG CONTINUOUS SCRAPERS, CERT TRANSPARENCY MONITORING, DNS RESOLVER, AND MORE! IT WILL BE THE BEST AND CHEAPEST TOOL IN THE WORLD! BUILT USING RUST AND LEVERAGING THE BEST SCALABLE AND CHEAPEST AWS CLOUD SERVICES. ## Create HTTP service that exposes our scanners Similar to how Caido works, build a web frontend for the scanner that can kick off scans. This will open up a lot of flexibility in how we can use the scanner, and is very similar to how Caido is built: - Run scanner on localhost and get your SwiftUI app to interface into it - Run the scanner in a container or VM and kick off scans - Create another web service that acts as an aggregator of scanners. Manage multiple scanners through one place. The scanners will provide callbacks to the aggregator service. The GUI can contact the aggregator service to get the status of all scanners and push jobs to other scanners. We have to make sure that the web services are not directly tied into the scanner code logic itself. Otherwise we would not be able to easily run the scanners from serverless functions or other code bases. ## Http1/1 and H2 packet generation using the pnet packet marco Is there already, or can I leverage the pnet packet macro for adding http1 and http2 support to pnet? That will make it easier to generate http packets. ### TCP Fast Open See if we want to implement TCP Fast Open for SNI scanning. Linux apparently does support it by default, windows doesn't. This could reduce the round trip time of SNI scraping tasks. ### Race condition testing implement a scanner that does this: Since this technique relies on crafting IP re-assembled packets and using TCP sequence numbers in a particular way, it will likely not be able to leverage the normal Engine as that leverages syn cookies to sync the transmit and receiving threads. ### HTTP1 Pipelining Can I use my engine to implement HTTP1 pipelining? ### HTTP-to-DNS Use my engine to very quickly resolve domains against CloudFlare. Can we make this very memory efficient as most packets will look very much alike. 
I could perhaps use the streaming json library ### Caido/Burp repeater using my HTTP engine Integrate my HTTP engine with a websocket frontend into the MPF iOS app. I'm not going to create a proxy (at least not yet :D) but something similar to the repeater feature in Burp. - One part that allows you to craft the binary request - Within the request I can use my template variables - I can select existing wordlists to use for my template variables - I can set rate limits for the scanner - Starting the scan will show the results in a table view - Double-clicking a result will drill into the result. The back arrow goes back to the table view - Easily import common HTTP request payloads This would actually be a feature that could help me find bugs (a bit) faster. At the moment a lot of my time is spent inside the repeater tab, manually crafting payloads. If I have all my wordlists and templates ready, I can use them far more easily. That in combination with my fast scanner could be a very powerful tool (of course I would have to be careful with rate limits). Since I can have my scanner running somewhere else, I can also use it from my phone. Of course that's a bit less screen space, but at least I could get some stuff to work. Of course, when this is built, what's stopping me from creating my own proxy? :D There I can also create a websockets frontend and then get a history in my iOS tool and send requests to the repeater. ``` ### `book/src/security_ideas/s3_dns_recon.md` **Absolute path:** `/root/mrpf/book/src/security_ideas/s3_dns_recon.md` ```markdown # Tricks using DNS lookups for recon Read this for discovering S3 buckets using DNS enumeration: - ``` ### `book/src/system_architecture/README.md` **Absolute path:** `/root/mrpf/book/src/system_architecture/README.md` ```markdown # System Architecture The system architecture of the MRPF project can be broken down into a few key components: ## MRPF API The MRPF API serves as the interface for clients to interact with the collected data and manage the scheduling of tasks. ## Task Manager The Task Manager is responsible for scheduling and managing various scanning tasks. It handles task creation, execution, and monitoring. ## Network Engine The Network Engine is a fast network scanning engine based on masscan. It is responsible for sending and receiving network packets using a separate receive and transmit thread. The engine itself exposes traits which can be used to implement different scanning techniques. ### TCP SYN Scanner The TCP SYN Scanner is a specific implementation of a scanning technique that utilizes TCP SYN packets to identify open ports on target hosts. It leverages the Network Engine for packet transmission and reception. ### HTTP/1.1 Scanner The HTTP/1.1 Scanner is designed to perform HTTP/1.1 requests to target hosts and analyze the responses. It can be used to identify web servers, gather information about web applications, and detect potential vulnerabilities. ### TLS Certificate Scanner The TLS Certificate Scanner is responsible for retrieving and analyzing TLS/SSL certificates from target hosts. It can be used to identify the certificate issuer, expiration dates, and potential vulnerabilities in the SSL configuration. It uses the Network Engine for network communication and has a custom TLS implementation to extract certificate information without relying on external libraries. ### DNS Resolver NOT IMPLEMENTED YET The DNS Resolver is responsible for querying DNS servers to resolve domain names to IP addresses.
It utilizes the Network Engine to perform these queries and gather relevant data. ### Whois Resolver NOT IMPLEMENTED YET The Whois Resolver is responsible for querying Whois databases to retrieve information about domain names and IP addresses. It utilizes the Network Engine to perform these queries and gather relevant data. ### Models The various components of the MRPF project utilize a shared set of models defined in the `mrpf_models` crate. These models define the data structures and types used throughout the system, ensuring consistency and interoperability between different components. #### Iterator Models Most of the models are reasonably straightforward. However, the `Ipv4Range`, `Ipv4Addresses`, `Ports` and `PortRange` models deserve special attention. These models implement a custom Iterator that shuffles the order of IP addresses and ports to avoid predictable scanning patterns. The algorithm idea was taken from [masscan](https://github.com/robertdavidgraham/masscan). The iterator ensures each item is returned only once but avoids having to store all items in memory at once. It accomplishes this by only storing the start and end values within the `Ipv4Addresses` and `PortRange` models. The iterator uses a Feistel cipher to generate a pseudo-random permutation of the range of values, allowing for efficient iteration without repetition. The low memory footprint is very useful in our task manager as it reduces the SQS message sizes, the database storage requirements and the RAM usage of the workers. In the future we may introduce similar iterators, for instance for domain names. When trying to fuzz for new subdomains, we could reduce the memory footprint by storing them as a hierarchical structure instead of a flat list. ``` ### `book/src/task_manager/README.md` **Absolute path:** `/root/mrpf/book/src/task_manager/README.md` ```markdown # Task Manager The task manager is a core component of the MRPF system, responsible for orchestrating and managing the execution of various tasks. It provides a flexible and extensible framework for defining, scheduling, and executing tasks, as well as handling task dependencies and data management. It takes inspiration from Nuclei's templating system, but is fully integrated with the rest of the MRPF system. This means we're able to run any scan type we've implemented, pull data like wordlists, target scope regexes and apex zones, and use that data in our tasks. It also allows for the fan-out/fan-in pattern, to distribute tasks across multiple workers and aggregate results from multiple tasks. We support complex execution workflows like parallel execution, conditions(*) and loops(*). It follows the [core principles](core_principles.md) of the MRPF system. For instance, the HTTP1/x module doesn't hide the differences between IP addresses, host names and TLS SNI behind a single "host" abstraction. Here's an example of what this looks like in a task template: ```yaml - kind: http1 ip: $[ipv4_address_ranges] port: 443 tls: true sni: localhost body: | GET / HTTP/1.1 Host: localhost ``` The observant reader will see that even here we're taking some liberties. For instance, the body is treated as UTF-8 in this example. We will offer various other ways to define the request body in the future, such as hex encoded binary data to allow complex attacks on encodings or specific byte sequences. A rough sketch of how such a templated body can expand into raw request bytes is shown below.
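To make the body handling a bit more concrete, here is a minimal sketch of byte-level template substitution, in the spirit of the `${host}` style templates used by the HTTP/1.1 scanner CLI later in this document. The `expand_template` helper and its signature are illustrative assumptions, not the actual `Template` implementation in `mrpf_http1_scanner`.

```rust
/// Expand a byte-level template by substituting `placeholder` with each candidate value.
/// Illustrative sketch only; the real Template type in this repo may work differently.
fn expand_template(template: &[u8], placeholder: &[u8], values: &[&[u8]]) -> Vec<Vec<u8>> {
    values
        .iter()
        .map(|value| {
            let mut out = Vec::with_capacity(template.len() + value.len());
            let mut i = 0;
            while i < template.len() {
                if template[i..].starts_with(placeholder) {
                    out.extend_from_slice(value);
                    i += placeholder.len();
                } else {
                    out.push(template[i]);
                    i += 1;
                }
            }
            out
        })
        .collect()
}

fn main() {
    let template = b"GET / HTTP/1.1\r\nHost: ${host}\r\nConnection: close\r\n\r\n";
    let hosts: [&[u8]; 2] = [b"example.com", b"localhost"];
    for body in expand_template(template, b"${host}", &hosts) {
        // The payload stays raw bytes end-to-end; lossy printing is only for the demo.
        print!("{}", String::from_utf8_lossy(&body));
    }
}
```

Multiple placeholders and the Cartesian-product expansion described in the templating chapter would layer on top of this, but the core idea is the same: substitution happens on raw bytes, so nothing forces the body to be valid UTF-8.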
(*) Actually not yet implemented :) ``` ### `book/src/task_manager/data.md` **Absolute path:** `/root/mrpf/book/src/task_manager/data.md` ```markdown # Data Management Each task collection run maintains temporary state within the tasks_data table. This allows tasks to build on previous outputs. The task templating variables allow you to define how data is passed between tasks. ``` ### `book/src/task_manager/examples.md` **Absolute path:** `/root/mrpf/book/src/task_manager/examples.md` ```markdown # Examples Here are some examples of task collection structures to illustrate different use cases. Note that this is the YAML representation of how developers would define these task collections. The MRPF API will convert these definitions in a slightly different JSON internal representation to allow for the various features like task generators, data aggregations, conditionals and loops. ```yaml - name: Example Task Collection description: An example task collection demonstrating various features. tasks: - sequential: - kind: get_target target_id: "victim" output: target - parallel: - kind: tcp_syn_scanner ipv4_address_ranges: "$[target.ip_ranges]" ports: "80,443,8080-8089" output: open_ports - kind: dns_lookup domains: "$[domains[*].fqdn]" record_types: A output: domains - if: - when: contains: { var: "$[open_ports]", value: "443" } then: - kind: http_fuzz method: GET host: "$[target.domain]" sni: "$[target.domain]" tls: true content: | GET $[paths] HTTP/1.1 Host: $[target.domain] output: http_responses - when: contains: { var: "$[open_ports]", value: "80" } then: - kind: http_fuzz method: GET host: "$[target.domain]" tls: false content: | GET $[paths] HTTP/1.1 Host: $[target.domain] output: http_responses - else: - kind: notification message: "No HTTP ports open on $[target.domain], skipping HTTP fuzzing." - if: - when: non_empty: "$[http_responses]" then: - loop: condition: "$[http_responses[*].status_code]" do: - kind: notification message: "Received status code $[item] from $[target.domain]" - else: - kind: notification message: "No HTTP responses for $[target.domain]" - loop: condition: "$[http_responses[*].status_code]" do: - kind: notification message: "Received status code $[item] from $[target.domain]" ``` ``` ### `book/src/task_manager/infrastructure.md` **Absolute path:** `/root/mrpf/book/src/task_manager/infrastructure.md` ```markdown # Infrastructure The system leverages AWS services for scalability and security, with SQS for event-driven invocations. 
```mermaid %%{init: { "theme": "base", "themeVariables": { "primaryColor": "#0b5cab", "primaryTextColor": "#ffffff", "lineColor": "#6b7280", "tertiaryColor": "#eef2ff", "fontFamily": "Segoe UI, Roboto, Helvetica, Arial, sans-serif" }, "flowchart": { "diagramPadding": 8, "curve": "basis" } }}%% flowchart TD %% --------------------------- %% Nodes (defined first) %% --------------------------- EB[EB Scheduler] TM[Task Manager] WP1[Workers] WP2[Workers] SQS[SQS Queue] TT[(Task Table)] DT[(Data Table)] ST[(Statistics Table)] %% --------------------------- %% Subgraphs / groupings %% --------------------------- subgraph PG[PostgreSQL] ST TT DT end subgraph VPC[AWS VPC] TM PG WP2 end subgraph EXT[External] WP1 end %% --------------------------- %% Edges %% --------------------------- EB -->| Check task timeouts 5 min | SQS EB -->| Gather statistics 1 hour | SQS EB -->| Cleanup old tasks and data 1 day | SQS WP1 -->| Push completion | SQS WP2 -->| Push completion | SQS SQS -->| Trigger invoke | TM TM -->| Store statistics | ST TM -->| Manage tasks | TT TM -->| Mutate data | DT WP2 -->| Mutate data | DT TM -->| Dispatch tasks | WP1 TM -->| Dispatch tasks | WP2 %% --------------------------- %% Styling %% --------------------------- %% Database tables as cylinders (already set by [( )]); add color: classDef db fill:#89CFF0,stroke:#0096FF,stroke-width:1px,color:#3b2f00; class ST,TT,DT db; classDef rounded rx:8,ry:8,stroke:#2b5fab,stroke-width:1.2px,fill:#0b5cab,color:#ffffff; class TM,WP1,WP2 rounded; classDef roundedInfra rx:8,ry:8,stroke:#F36717,stroke-width:1.2px,fill:#E25606,color:#ffffff; class EB,SQS roundedInfra; %% Subgraph backgrounds & borders style PG fill:#fff6e5,stroke:#ff8c00,stroke-width:2px,rx:10,ry:10 style VPC fill:#f0f7ff,stroke:#0b5cab,stroke-width:1.5px,rx:10,ry:10 style EXT fill:#f7f7f7,stroke:#9ca3af,stroke-width:1px,rx:10,ry:10 %% Links (edges) linkStyle default stroke:#6b7280,stroke-width:2px ``` ``` ### `book/src/task_manager/initial_description_whilst_building.md` **Absolute path:** `/root/mrpf/book/src/task_manager/initial_description_whilst_building.md` ```markdown # Task Manager I’m building a task manager in Rust using sqlx and Postgres, leveraging AWS Lambda, SQS and EventBridge. ## Concepts - A task manager schedules tasks for execution - The task manager receives instructions through a queue - The task manager manages state through access to the PostgreSQL tasks table - Each instruction can be run asynchronously and out-of-order. Task order is maintained by leveraging PostgreSQL features - Each task has a unique identifier, a completion_policy (fail_on_failure, continue_on_partial_failure, etc) and a status (pending, running, completed, failed) - We differentiate between container and leaf tasks - Container tasks are used to group other leaf or container tasks - Leaf tasks typically have results that need to be stored in the data table after completion - Leaf tasks run on workers - Workers reside outside of the VPC containing the database. Data to be stored in the data table must be sent through the SQS queue. These workers are preferred as they have a larger public ip pool, lower cost (no VPC/NAT GW traversal) and can scale more easily. - Workers inside the VPC are available for tasks that directly mutate data. This allows for more efficient data mutations when working with large payloads. - Leaf tasks are never allowed to directly mutate task information inside the database. Instead, all mutations flow through the SQS queue triggering the task manager. 
- Container tasks determine if it's children tasks should run sequentially or in parallel. - Container task status is calculated based on the status of it's children tasks - Leaf tasks can instruct the task manager to create new container or leaf tasks into a destination container task. - Creation of new tasks is done through messages in the task manager queue. It will generate create_task messages for the task manager to process. The task_completed message for the task generating the new task messages will include the total amount of tasks that should be created. The task manager will use this information to track the progress of task creation inside the PostgreSQL tasks table and ensure all tasks are accounted for before continuing. ## Infrastructure The system leverages AWS services for scalability and security, with SQS for event-driven invocations. ```mermaid %%{init: { "theme": "base", "themeVariables": { "primaryColor": "#0b5cab", "primaryTextColor": "#ffffff", "lineColor": "#6b7280", "tertiaryColor": "#eef2ff", "fontFamily": "Segoe UI, Roboto, Helvetica, Arial, sans-serif" }, "flowchart": { "diagramPadding": 8, "curve": "basis" } }}%% flowchart TD %% --------------------------- %% Nodes (defined first) %% --------------------------- EB[EB Scheduler] TM[Task Manager] WP1[Workers] WP2[Workers] SQS[SQS Queue] TT[(Task Table)] DT[(Data Table)] ST[(Statistics Table)] %% --------------------------- %% Subgraphs / groupings %% --------------------------- subgraph PG[PostgreSQL] ST TT DT end subgraph VPC[AWS VPC] TM PG WP2 end subgraph EXT[External] WP1 end %% --------------------------- %% Edges %% --------------------------- EB -->| Check task timeouts 5 min | SQS EB -->| Gather statistics 1 hour | SQS EB -->| Cleanup old tasks and data 1 day | SQS WP1 -->| Push completion | SQS WP2 -->| Push completion | SQS SQS -->| Trigger invoke | TM TM -->| Store statistics | ST TM -->| Manage tasks | TT TM -->| Mutate data | DT WP2 -->| Mutate data | DT TM -->| Dispatch tasks | WP1 TM -->| Dispatch tasks | WP2 %% --------------------------- %% Styling %% --------------------------- %% Database tables as cylinders (already set by [( )]); add color: classDef db fill:#89CFF0,stroke:#0096FF,stroke-width:1px,color:#3b2f00; class ST,TT,DT db; classDef rounded rx:8,ry:8,stroke:#2b5fab,stroke-width:1.2px,fill:#0b5cab,color:#ffffff; class TM,WP1,WP2 rounded; classDef roundedInfra rx:8,ry:8,stroke:#F36717,stroke-width:1.2px,fill:#E25606,color:#ffffff; class EB,SQS roundedInfra; %% Subgraph backgrounds & borders style PG fill:#fff6e5,stroke:#ff8c00,stroke-width:2px,rx:10,ry:10 style VPC fill:#f0f7ff,stroke:#0b5cab,stroke-width:1.5px,rx:10,ry:10 style EXT fill:#f7f7f7,stroke:#9ca3af,stroke-width:1px,rx:10,ry:10 %% Links (edges) linkStyle default stroke:#6b7280,stroke-width:2px ``` ## Questions - Analyze infrastructure and concepts - Draw me examples of the task tree structure. 
- Illustrate the flow through a set of tasks using mermaid state diagrams - Create a mermaid entity relationship diagram for the database ## Database schema ```mermaid erDiagram %% Enums task_kind_enum { TEXT Container TEXT TcpSynGenerator TEXT TcpSyn TEXT TcpSynAggregator TEXT TertiaryFilter TEXT WordFilter } task_state_enum { TEXT Pending TEXT Running TEXT PendingTaskCreation TEXT PartiallyFailed TEXT Succeeded TEXT Failed TEXT Disabled TEXT Timeout } task_execution_mode_enum { TEXT Sequential TEXT Parallel } completion_policy_enum { TEXT FailOnFailure TEXT ContinueOnPartialFailure TEXT ContinueOnFailure } worker_requirements_enum { TEXT DataTableAccess } %% Tables tasks { UUID id PK task_state state UUID parent_task_id FK JSONB definition SMALLINT expected_task_count UUID[] created_task_ids TIMESTAMPTZ updated_at } task_templates { UUID id PK task_kind kind TEXT name TEXT description BOOLEAN reusable INTEGER version JSONB config TIMESTAMPTZ updated_at } task_containers { UUID id FK task_execution_mode execution_mode } task_generators { UUID id FK SMALLINT expected_task_count UUID[] created_task_ids } task_aggregators { UUID id FK UUID aggregate_container_id FK } task_data { UUID id FK TEXT key JSONB value } %% Relationships tasks ||--o{ tasks : parent_task_id task_containers ||--|| tasks : id task_generators ||--|| tasks : id task_aggregators ||--|| tasks : id task_aggregators }o--|| task_containers : aggregate_container_id task_data ||--|| tasks : id ``` ## Ideas and Future Work We should have a Worker crate that brings all the workers together. - Signal hardware capability - Signal what capabilities it has - Run various tasks in parallel potentially - Put capabilities (aka, all of our scanner types) into features to allow us to build smaller dedicated binaries Perhaps this isn't needed, I thought wrote about this before I had my initial task_manager rust code up and running. ``` ### `book/src/task_manager/task_generators.md` **Absolute path:** `/root/mrpf/book/src/task_manager/task_generators.md` ```markdown # Task generation and aggregation The task manager allows for dynamic generation of distributed tasks based on data captured during a task collection run. For example, we might capture new domain names from scraping TLS certificates of known hosts on a target. Then we can feed these new domains into task generator for DNS resolution. The Task generator will create the required tasks in the running task collection to a destination container. These dynamically generated tasks typically will be run in parallel across multiple workers to avoid rate limits and speed up the process. The output of these tasks are all the same and often we want to aggregate the results into a single output destination. For this we introduce task aggregators. These task aggregators are also created by the generator task and put in the queue of the task collection to be run after the generated tasks are done. The aggregator will then collect all the results from the generated tasks and aggregate them into a single output. ## Ideas and Future Work ... 
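One rough sketch that could live under this heading: the kind of messages a generator and the task manager might exchange, with hypothetical type and field names loosely following the database schema above (`expected_task_count`, destination container, JSONB task definition). It assumes `serde` with the derive feature, `serde_json`, and `uuid` with its serde feature.

```rust
use serde::{Deserialize, Serialize};
use uuid::Uuid;

/// Hypothetical create_task message a generator emits onto the task manager queue.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateTask {
    /// Destination container the generated leaf task should be created under.
    pub destination_container_id: Uuid,
    /// Task definition, stored as JSONB on the tasks table.
    pub definition: serde_json::Value,
}

/// Hypothetical completion message for the generator task itself.
/// `expected_task_count` lets the task manager verify that every create_task
/// message has been processed before it schedules the aggregator.
#[derive(Debug, Serialize, Deserialize)]
pub struct GeneratorCompleted {
    pub task_id: Uuid,
    pub expected_task_count: u16,
}
```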
``` ### `book/src/task_manager/tasks_and_containers.md` **Absolute path:** `/root/mrpf/book/src/task_manager/tasks_and_containers.md` ```markdown ``` ### `book/src/task_manager/templates.md` **Absolute path:** `/root/mrpf/book/src/task_manager/templates.md` ```markdown # Templating Our Task Manager has a built-in templating engine that allows you to provide static data, reference data stored in the task collection database table and use functions for basic data transformations. The rendered result of a template is a `DataKind`, which will always be a collection of values of the same type. This is even the case if there's just a single value: it will be wrapped in a set. ## Literals The most basic type of expression is a literal. This allows you to provide a static value directly to a task argument. Some examples are: - String literal: `some_literal_string` - Integer literal: `42` ## References References allow you to retrieve data from various places in the system or task definition. The basic syntax for a reference is as follows: `$[<reference_type>:<identifier>]` Where `<reference_type>` indicates the type of the reference and `<identifier>` is the identifier of the specific data you want to reference. At the moment we support Data References and Task Parameter References. Since data references are the most commonly used, the reference type can be omitted and it will default to a data reference. ## Data References Data references allow you to reference data stored in the task collection's data storage. During runtime of a task, these references will be resolved to their actual values from the database. The basic syntax for a data reference is as follows: `$[some_key]` Where `some_key` is the key of the data within a task collection you want to reference. When providing an expression to a task argument, the task manager will evaluate the expression at runtime and replace it with the corresponding value from the task collection data storage (retrieved from the PostgreSQL database). A lightweight version of JSONPath syntax is supported for accessing nested data structures. Here are some examples: - Retrieve a nested key `$[another_key.hello]` - Access an array element by index `$[another_key.hello[0]]` TODO: This example doesn't match my another_key data example, fix it - Return an array of values from an array of objects `$[another_key.how.are.[*].you]` or `$[another_key.how.*.you]` Our database supports a variety of data types, including sets of strings, JSON objects and predefined common models like IPv4 Ranges and Domains. When referencing data, the task manager will automatically handle type conversions as needed to ensure that the data is in the correct format for the task argument. Here is an example of some data that might be stored in a task collection. Specific object properties can be extracted through the JSONPath syntax mentioned above. | key | kind | value | |----------------|---------------|---------------------------------------| | some_key | set of strings | ["some", "string"] | | another_key | generic_object | {"Hello": ["World", "Moon"], "how": {"are": [{"you": "doing"}, {"things": "going"}]}, ...} | | some_ips | ipv4_ranges | [{"start": "127.0.0.1", "end": "127.0.0.1"}, ...] | | some_domains | set of domains | [{"fqdn": "example.com", "is_active": true, "dns_chain": ["a.com", ["10.0.0.1", "10.0.0.2"]]}, ...] | ### Task Parameter References In certain cases you want to be able to reference parameters defined within the task itself.
This can be especially handy if you want to re-use standard values for different arguments. One example of this is when performing HTTP requests. Often a simple GET request will have the same content. It might look something like this: ```http GET / HTTP/1.1 Host: example.com ``` Lets say we want to fuzz across different SNI values. We could construct a standard body and use self referencing data variables to insert the SNI value in both the TLS SNI field and the Host header. ```yaml - kind: http1 ip: 10.0.0.1 port: 443 tls: true sni: $[host_fuzzing] content: | GET / HTTP/1.1 Host: $[task:sni] ``` Combining this with our data references, we could create standard content payloads and re-use them across multiple tasks. ```yaml - insert_data: key: http_get_content kind: string value: | GET / HTTP/1.1 Host: $[task:sni] - kind: http1 ip: 10.0.0.1 port: 443 tls: true sni: $[host_fuzzing] content: $[http_get_content] ... ``` ## Combined literals and references You can combine literals and data references within a single expression. When doing so, the task manager will evaluate the entire expression and produce a set of values based on all possible combinations of the literals and referenced data. Lets take the folllowing example: - Combining a literal with a data reference `prefix_$[some_key]_suffix` If `some_key` contains the values `["some", "string"]`, the resulting set of values would be: - `prefix_some_suffix` - `prefix_string_suffix` ## Functions In addition to simple data references, the expression syntax supports a variety of built-in functions that can be used to manipulate and transform data. Here are some examples: - `capitalize($[some_data_key])`: Converts a string to uppercase. - `split($[some_data_key], ".")`: Splits a string into an array based on the specified delimiter (in this case, a period). ## Examples Here are some more complex examples that combine literals, data references and functions: ### Retrieving and transforming data This expression first splits the string retrieved from `some_data_key` at each period, takes the first element of the resulting array, and then capitalizes it. `capitalize(split($[some_data_key], ".")[0])` ### Accessing nested data This expression retrieves an array of values from the `property` field of each object in the `array_key` array within the `some_data_key` data structure. `$[some_data_key.array_key[*].property]` We also support the alternative wildcard JSONPath syntax: `$[some_data_key.array_key.*.property]` ### Applying templates to task definitions Here's an example of how these expressions might be used in the context of a TCP SYN task generator: ```yaml - kind: tcp_syn_generator ipv4_address_ranges: "$[target_ip_ranges]" ports: - 80 - "$[https_ports]" ``` The result of rendering this template would be a set of TCP SYN tasks, each with a specific IP address from the `target_ip_ranges` data reference and a destination port that is either 80 or one of the ports specified in the `https_ports` data reference. Note that templates are always converted to a set of values, even if the result is a single value. This ensures consistency in how task arguments are handled. In the example above, if `https_ports` contains the values `[443, 8443]`, the resulting task ports variable would contain a single set of integers `{80, 443, 8443}`. ipv4_address_ranges would contain a set of all individual IP addresses derived from the provided ranges. For ease of use, parameters for task templates can be either a single expression or an array of expressions. 
If an array is provided, the results of each expression will be combined into a single set. For example, the following two configurations are equivalent: ```yaml ipv4_address_ranges: "$[target_ip_ranges]" ``` ```yaml ipv4_address_ranges: - "$[target_ip_ranges]" ``` ### HTTP fuzzing examples When performing HTTP fuzzing, we'd often want to iterate over several different data variables (e.g. wordlists) and generate all possible combinations. Lets see how this can be achieved with our templating engine. Lets say we have the following data stored in our task collection: | key | kind | value | |----------------|---------------|---------------------------------------| | path_traversal | array of strings | ["../", ".;/"] | | paths | array of strings | ["/admin", "/login"] | | hostnames | array of strings | ["localhost", "127.0.0.1"] | We can then define an HTTP fuzzing task template like this: ```yaml - kind: http_fuzz method: GET host: target.com sni: target.com body: | GET /$[path_traversal]/$[paths] HTTP/1.1 Host: $[hostnames] ``` When this template is rendered, the task manager will generate a set of HTTP fuzzing tasks that cover all combinations of the provided path traversal strings, paths and hostnames. This would generate $2 (path\_traversal) * 2 (paths) * 2 (hostnames) = 8$ unique body payloads: ```http GET /../admin HTTP/1.1 Host: localhost ``` ```http GET /.;/admin HTTP/1.1 Host: localhost ``` ```http GET /../login HTTP/1.1 Host: localhost ``` ```http GET /.;/login HTTP/1.1 Host: localhost ``` ```http GET /../admin HTTP/1.1 Host: 127.0.0.1 ``` ```http GET /.;/admin HTTP/1.1 Host: 127.0.0.1 ``` ```http GET /../login HTTP/1.1 Host: 127.0.0.1 ``` ```http GET /.;/login HTTP/1.1 Host: 127.0.0.1 ``` Note that we could have taken this further by also templating the `host` and `sni` fields, but this should give you an idea of how powerful the templating engine can be when combined with data stored in the task collection. ``` ### `book/src/task_manager/workers.md` **Absolute path:** `/root/mrpf/book/src/task_manager/workers.md` ```markdown # Workers The task manager offers a few different types of workers, internal, external and bare metal workers. Depending on the type of task, it will be dispatched to a dedicated queue for the type of worker required. ## External Workers (AWS Lambda) These workers run on AWS Lambda outside of the VPC. They do not have direct access to the PostgreSQL database. They are typically used for tasks that communicate to external services for recon purposes. Unfortunately AWS Lambda does not support raw socket operations, so tasks that require raw socket access cannot be run on these workers. Examples of such tasks are TCP SYN scanning or custom TLS scanning. Useful for creating distributed tasks towards services like crt.sh, Censys, Shodan, or other public services. NOTE: they do NOT have access to the PostgreSQL database. This means storing task results are handled by passing SucceededWithData messages back to the Task Manager. ## Internal Workers (AWS VPC bound Lambda) These workers run on AWS Lambda inside of the VPC. They have read-write access to the tasks_data and recon tables in the PostgreSQL database. They are typically used for tasks that perform data mutations or filtering operations. Since they run inside of the VPC, they can connect directly to the PostgreSQL database. ## Bare Metal Workers Bare metal workers have raw socket access. This means they can use our custom network scanning engines like TcpSyn scanning and custom TLS scanning. 
There is a dedicated SQS queue for bare metal workers. On launch, the worker will poll SQS for new tasks to execute. Using environment variables, the worker can configure a timeout value to poll for messages on the SQS queue. If there are no messages within the given timeout, the worker will shut itself down. This feature is useful for running the worker on EC2 to cut down costs. When running on EC2, the workers are placed in the VPC containing the PostgreSQL database. This means the bare metal workers have read-write access to the tasks_data and recon tables in the PostgreSQL database. This makes them suitable for both scanning tasks and heavy-duty data mutation tasks. When running outside of AWS, the bare metal workers could use SucceededWithData messages to pass data back to the Task Manager. However, this is not implemented yet. It would require us to create either separate task definitions OR have a flag in the task definition to indicate if the worker is running inside or outside of AWS. ## Ideas and Future Work - Implement workers running on other environments like Azure, or bare metal hardware outside of AWS. ``` ### `clients/mrpf_api_client/src/base.rs` **Absolute path:** `/root/mrpf/clients/mrpf_api_client/src/base.rs` ```rust pub struct MpfApiClient { base_url: String, api_key: String, } impl MpfApiClient { pub fn new(base_url: &str, api_key: &str) -> MpfApiClient { MpfApiClient { base_url: base_url.to_string(), api_key: api_key.to_string(), } } pub fn get_base_url(&self) -> &str { &self.base_url } pub fn get_api_key(&self) -> &str { &self.api_key } } ``` ### `clients/mrpf_api_client/src/domains.rs` **Absolute path:** `/root/mrpf/clients/mrpf_api_client/src/domains.rs` ```rust use super::base::MpfApiClient; use mrpf_core::recon::Domain; impl MpfApiClient { /// List all domains for a given target id. /// /// # Arguments /// /// * `target_id` - A string slice that holds the target id. /// /// # Returns /// /// A vector of `Domain` records for the target. pub fn list_domains(&self, _target_id: &str) -> Vec<Domain> { // vec![Domain::new( // "example.com".to_string(), // target_id.to_string(), // Some(false), // vec!["www.example.com".to_string()], // vec!["192.0.2.1".to_string()], // .into_iter() // .collect(), // )] vec![] } } ``` ### `clients/mrpf_api_client/src/lib.rs` **Absolute path:** `/root/mrpf/clients/mrpf_api_client/src/lib.rs` ```rust mod base; mod domains; pub use base::MpfApiClient; ``` ### `clients/mrpf_app/README.md` **Absolute path:** `/root/mrpf/clients/mrpf_app/README.md` ```markdown # Guys just wanna have fun Fuck my ambition, I don't need: - bounties - results - found bugs I just want to have fun and build out a nice fancy iOS app. It's going to have flashy graphs, a great, overengineered client-side client for the MPF API, statistics from my infrastructure on Azure and AWS, an AR graph view of IPs, domains, URLs and JavaScript files to show off the Vision Pro, MPF recipes customizable through the iOS app, a universal app with widgets, app intents, shortcuts, macOS, iOS, iPadOS and watchOS support, macOS tables, a tabview on iOS, two- and three-way split views on larger screens and watchOS, push notifications, triggering jobs and lambdas, an animated main logo on the start screen, Metal shaders just because I can, an integrated adventure game about a blues-playing hacker and a whole bunch more. - Build a module to draw sequence diagrams! Use SwiftUI shapes for this and take inspiration from mermaid.js for how to do the definition.
But maybe I can improve on it by using YAML. I would really like it to be tailored much more nicely to HTTP traffic, so that you can pass nicely formatted JSON payloads and headers that you can expand. ## Sequence Diagrams I want to be able to easily create beautiful sequence diagrams from within my iOS app. I want to be able to store the data on my backend API so I would need to be able to convert it. Let's start with a YAML-based model: ```yaml SequenceDiagram: Actors: - name: Alice icon: person - name: ClientLibrary icon: code - name: Server icon: server Steps: - from: Alice to: Server arrow: -->> message: "Hi ClientLibrary" payload: | Authentication: Bearer { "id": None } - from: Server to: Alice arrow: <<-- message: "HTTP 200" payload: | { "id": 1 } ``` ``` ### `clients/mrpf_cli/README.md` **Absolute path:** `/root/mrpf/clients/mrpf_cli/README.md` ```markdown # MRPF CLI ## How to run cargo build && sudo ../../target/debug/mrpf ## Things to do to make it more usable [Good thing to read](https://rust-cli.github.io/book/in-depth/human-communication.html) - Add command line arguments. - Will need sub modules so that I can easily add new features - clap is very popular it seems - Allow for configuration files - For instance, I need to provide the router IP address. - Would be great if I can use both a text file like TOML AND environment variables - [config](https://crates.io/crates/config) seems to be a good crate for this ``` ### `clients/mrpf_cli/src/main.rs` **Absolute path:** `/root/mrpf/clients/mrpf_cli/src/main.rs` ```rust // cargo build && sudo target/debug/mrpf mod progress; use mrpf_dns_resolver::DnsResolver; use mrpf_engine::ScannerConfig; use mrpf_engine::network::dns::DnsQueryType; use mrpf_engine::utils::{get_interface, write_json_file}; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use mrpf_http1_scanner::{Http1Scanner, Template}; use mrpf_tcp_syn_scanner::TcpSynScanner; use mrpf_tls_scanner::TlsScanner; use progress::{TcpSynProgressHandlerCLI, TlsProgressHandlerCLI}; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{self, BufRead}; use std::net::Ipv4Addr; use std::path::Path; fn main() { setup_tracing(); let interface = get_interface(Some("en0")); let src_mac = interface.mac.expect("No MAC address found"); // Home IPs let src_ip = Ipv4Addr::new(192, 168, 2, 244); // Spoofed IP let router_ip = Ipv4Addr::new(192, 168, 2, 254); // iConnect IPs // let src_ip = Ipv4Addr::new(10, 51, 62, 249); // Spoofed IP // let router_ip = Ipv4Addr::new(10, 51, 60, 1); let config = ScannerConfig::builder() .interface(interface) .src_mac(src_mac) .src_ip(src_ip) .router_ip(router_ip) .build(); // dns_scan(config.clone()); // http1_scan(config.clone()); tls_scan(config.clone()); // tcp_syn_scan(config.clone()); } /// Read a list of IP addresses from a file and convert them to Ipv4AddressRanges fn read_ip_addresses_from_file(file_path: &str) -> Ipv4AddressRanges { let path = Path::new(file_path); let file = File::open(path).unwrap(); let reader = io::BufReader::new(file); let mut ip_addresses = Vec::new(); for line in reader.lines() { let line = line.unwrap(); ip_addresses.push(line); } ip_addresses.into() } // Entrypoint for the HTTP/1.1 scanner #[expect(dead_code)] fn http1_scan(config: ScannerConfig) { let scanner = Http1Scanner::new(config); // let ipv4_addresses = read_ip_addresses_from_file("mrpf_cli/ip_ranges/x.com.txt"); // let ipv4_addresses = read_ip_addresses_from_file("mrpf_cli/ip_ranges/example.com.txt"); // let ipv4_addresses =
read_ip_addresses_from_file("mrpf_cli/ip_ranges/mortimer.nl.txt"); let ipv4_addresses = vec![ // "136.144.153.226", // "93.184.215.14" "95.101.74.119", // Watson ] .into(); let ports: PortRanges = "443".into(); println!("Running http1 scan..."); let template: Template = b"GET / HTTP/1.1\r\nHost: ${host}\r\nConnection: close\r\n\r\n".to_vec(); // let variables = HashMap::new(); let variables: HashMap, Vec>> = [( b"${host}".to_vec(), vec![ b"sit.api.apac.aswatson.com".to_vec(), b"www2.aswatson.com".to_vec(), b"api.apac.admin.aswatson.com".to_vec(), b"api.apac.aswatson.com".to_vec(), b"api.apac.devportal.aswatson.com".to_vec(), b"api.parknshop.com".to_vec(), b"api.pns.hk".to_vec(), b"api.watsons.co.id".to_vec(), b"api.watsons.co.th".to_vec(), b"api.watsons.com.hk".to_vec(), b"api.watsons.com.my".to_vec(), b"api.watsons.com.ph".to_vec(), b"api.watsons.com.sg".to_vec(), b"api.watsons.com.tw".to_vec(), b"api.watsons.vn".to_vec(), b"cb.watsons.com.my".to_vec(), b"cb.watsons.com.ph".to_vec(), b"digireceipt.watsons.co.th".to_vec(), b"digireceipt.watsons.com.ph".to_vec(), b"digireceipt.watsons.com.tw".to_vec(), b"digireceipt.watsonswine.com".to_vec(), b"dms.mbhk.aswatson.net".to_vec(), b"dms.pnshk.aswatson.net".to_vec(), b"dms.wtchk.aswatson.net".to_vec(), b"dms.wtcsg.aswatson.net".to_vec(), b"elecboy.com.hk".to_vec(), b"elecboyprd.fortress.com.hk".to_vec(), b"evhk.aswatson.com".to_vec(), b"getactive.watsons.co.th".to_vec(), b"getactive.watsons.com.hk".to_vec(), b"getactive.watsons.com.my".to_vec(), b"getactive.watsons.com.tw".to_vec(), b"getactive.watsonsworld.com".to_vec(), b"hriq.watsons.com.sg".to_vec(), b"hriquat.watsons.com.sg".to_vec(), b"infwtctwy6u3t.aswatson.net".to_vec(), b"mapi.app.aswatson.com".to_vec(), b"mapiuat.app.aswatson.com".to_vec(), b"mcms-api.app.aswatson.com".to_vec(), b"mcms.app.aswatson.com".to_vec(), b"mcmsuat-api.app.aswatson.com".to_vec(), b"mcmsuat.app.aswatson.com".to_vec(), b"mmedia.app.aswatson.com".to_vec(), b"mmediauat.app.aswatson.com".to_vec(), b"mpos.watsons.co.th".to_vec(), b"mpos.watsons.com.sg".to_vec(), b"mshare.app.aswatson.com".to_vec(), b"mshare.watsons.com.my".to_vec(), b"mshareuat.app.aswatson.com".to_vec(), b"naturalsbywatsons.com".to_vec(), b"nestdashboard.aswatson.com".to_vec(), b"pns.hk".to_vec(), b"preprod.api.admin.aswatson.com".to_vec(), b"preprod.api.apac.aswatson.com".to_vec(), b"preprod.api.aswatson.com".to_vec(), b"preprod.api.devportal.aswatson.com".to_vec(), b"sit.api.apac.admin.aswatson.com".to_vec(), b"sit.api.apac.aswatson.com".to_vec(), b"sit.api.apac.devportal.aswatson.com".to_vec(), b"smile.parknshop.com".to_vec(), b"watson.aswatson.com".to_vec(), b"www.elecboy.com.hk".to_vec(), b"www.naturalsbywatsons.com".to_vec(), b"www.pns.hk".to_vec(), b"www1.elecboy.com.hk".to_vec(), b"www10.parknshop.com".to_vec(), b"www10.pns.hk".to_vec(), b"www10.watsons.co.id".to_vec(), b"www10.watsons.com.hk".to_vec(), b"www10.watsons.com.my".to_vec(), b"www10.watsons.com.ph".to_vec(), b"www10.watsons.com.sg".to_vec(), b"www10.watsons.com.tw".to_vec(), b"www10.watsons.vn".to_vec(), b"www1010010.parknshop.com".to_vec(), b"www20.watsons.co.th".to_vec(), ], )] .into(); let results = scanner.run( ipv4_addresses, ports, vec![ "sit.api.apac.aswatson.com".to_string(), // "www.mortimer.nl".to_string(), ], template, variables, ); for (request, response) in &results { println!( "\n\n\n{:?}\n{:?}", String::from_utf8_lossy(request.body()), response ); } // println!("Writing results to http1_scan.json"); write_json_file("http1_scan.json", &results); } // Entrypoint for 
resolving DNS #[allow(dead_code)] fn dns_scan(config: ScannerConfig) { let scanner = DnsResolver::new(config); let domain_names = HashSet::from([ "www.example.com".to_string(), "www.google.com".to_string(), "www.nu.nl".to_string(), ]); println!("Running DNS resolver scan A Records..."); let results = scanner.run(domain_names, DnsQueryType::A); if !results.is_empty() { let filename = "dns_resolver_a.json"; println!( "Resolved {} domain names, writing to {}", results.len(), filename ); write_json_file(filename, &results); } else { println!("No domain names were resolved"); } // println!("Running DNS resolver scan, PTR records..."); // let ips = vec![ // "136.144.153.226".to_string(), // mortimer.nl VPS // "2.18.121.138".to_string(), // Akamai // ]; // let results = scanner.run(ips, DnsQueryType::PTR); // if !results.is_empty() { // let filename = "dns_resolver_ptr.json"; // println!( // "Resolved {} domain names, writing to {}", // results.len(), // filename // ); // write_json_file(filename, &results); // } else { // println!("No domain names were resolved"); // } } // Entrypoint for scanning TLS #[allow(dead_code)] fn tls_scan(config: ScannerConfig) { let scanner = TlsScanner::new(config); // let ipv4_addresses = read_ip_addresses_from_file("mrpf_cli/ip_ranges/x.com.txt"); // let ipv4_addresses = read_ip_addresses_from_file("mrpf_cli/ip_ranges/eu-west-1_ec2.txt"); // let ipv4_addresses = // read_ip_addresses_from_file("/Users/thiezn/Development/mrpf/mrpf_cli/ip_ranges/ing.txt"); let ipv4_addresses = vec![ "136.144.153.226", // my vps // "142.251.36.14", // google.com // "2.19.195.88", // nu.com behind akamai. NOTE: AKAMAI DOESN'T SEEM TO WORK WITH OUR SCANNER. IS IT THE CLIENT HELLO? ] .into(); let ports: PortRanges = "443".into(); let targets = vec![( // vec!["x.com".to_string()], // vec!["cogitor.security.aws.rabo.cloud".to_string()], // vec!["example.com".to_string()], // vec!["*".to_string()], vec![ "mail.mortimer.nl".to_string(), "google.com".to_string(), "www.nu.nl".to_string(), ], ipv4_addresses, ports, )]; println!("Running tls scan..."); let progress_handler = TlsProgressHandlerCLI::new(&targets); let results = scanner.run_with_handle(targets, progress_handler); // let results = scanner.run(targets); // This better shows errors // let results = scanner.run(targets); let filename = "sni_scan.json"; println!("Writing {} results to {}", results.len(), filename); write_json_file(filename, &results); } /// Main entry point for the application. /// /// This will eventually move to a separate binary crate. #[expect(dead_code)] fn tcp_syn_scan(config: ScannerConfig) { let ipv4_addresses = read_ip_addresses_from_file("mrpf_cli/ip_ranges/x.com.txt"); // let ports: Ports = "443".into(); // let ports: Ports = "1-1024,8080-8081,10443,20443,11443".into(); let ports: PortRanges = "80,443,8080-8081".into(); // let ports: Ports = "1-65535".into(); let scanner = TcpSynScanner::new(config); let progress_handler = TcpSynProgressHandlerCLI::new(ipv4_addresses.len(), ports.len()); println!("Running TCP Syn scan..."); let open_ports = scanner.run_with_handle(ipv4_addresses, ports, progress_handler); println!( "Found {:} hosts with one or more open ports", open_ports.len() ); println!("Writing open ports to open_ports.json"); write_json_file("open_ports.json", &open_ports); } /// Setup tracing with separate stdout and stderr layers. /// /// Honor RUST_LOG if set, default to info. fn setup_tracing() { tracing_subscriber::fmt() // disabling time is handy because CloudWatch will add the ingestion time. 
.without_time() // remove the name of the function from every log entry .with_target(false) // disable ANSI color codes in stdout // .with_ansi(false) .init(); } ``` ### `clients/mrpf_cli/src/progress/dns.rs` **Absolute path:** `/root/mrpf/clients/mrpf_cli/src/progress/dns.rs` ```rust ``` ### `clients/mrpf_cli/src/progress/mod.rs` **Absolute path:** `/root/mrpf/clients/mrpf_cli/src/progress/mod.rs` ```rust mod dns; mod tcpsyn; mod tls; pub(crate) use tcpsyn::TcpSynProgressHandlerCLI; pub(crate) use tls::TlsProgressHandlerCLI; ``` ### `clients/mrpf_cli/src/progress/tcpsyn.rs` **Absolute path:** `/root/mrpf/clients/mrpf_cli/src/progress/tcpsyn.rs` ```rust use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use mrpf_engine::ProgressHandler; use mrpf_tcp_syn_scanner::TcpSynMessage; use std::collections::{HashMap, HashSet}; use std::net::Ipv4Addr; use std::sync::mpsc::Receiver; use std::time::Duration; const UPDATE_INTERVAL: Duration = Duration::from_secs(1); /// CLI ProgressHandler for the TCP SYN Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. A progress /// bar will be displayed in the CLI to show the progress of the scan. #[derive(Debug)] pub struct TcpSynProgressHandlerCLI { total_targets: u64, } impl TcpSynProgressHandlerCLI { pub fn new(ip_address_count: u32, port_count: u16) -> Self { Self { total_targets: (ip_address_count * port_count as u32) as u64, } } } impl ProgressHandler for TcpSynProgressHandlerCLI { type Message = TcpSynMessage; type Results = HashMap>; fn handler(&mut self, receiver: Receiver) -> Self::Results { // Variables holding statistics let mut results: HashMap> = HashMap::with_capacity(32); let rate_limit: Option = None; // not sure why/what? let retries = 0; let mut total_packets_received: u64 = 0; let mut total_rx_buffer_full: u64 = 0; let mut total_syn_packets_sent: u64 = 0; let mut total_packets_sent: u64 = 0; let mut total_tx_buffer_full: u64 = 0; let mut open_ports_found = 0; // Prepare our CLI progress indicator let multi_pb = MultiProgress::new(); let found_progress = multi_pb.add(ProgressBar::new_spinner()); found_progress.set_message("0"); found_progress.set_style( ProgressStyle::default_spinner() .template("{spinner:.yellow} Open ports: {msg:.green}") .expect("invalid template"), ); found_progress.enable_steady_tick(UPDATE_INTERVAL); let buffer_spinner = multi_pb.add(ProgressBar::new_spinner()); buffer_spinner.set_message("0/0"); buffer_spinner.set_style( ProgressStyle::default_spinner() .template("{spinner:.yellow} Tx/Rx Buffer full: {msg:.red}") .expect("invalid template"), ); let total_packets_spinner = multi_pb.add(ProgressBar::new_spinner()); total_packets_spinner.set_message("0/0"); total_packets_spinner.set_style( ProgressStyle::default_spinner() .template("{spinner:.yellow} Tx/Rx Packets: {msg:.blue}") .expect("invalid template"), ); total_packets_spinner.enable_steady_tick(UPDATE_INTERVAL); let total_scan_progress_bar = multi_pb.add(ProgressBar::new(self.total_targets)); total_scan_progress_bar.set_message("0"); total_scan_progress_bar.set_style( ProgressStyle::default_bar() .template(get_progress_stylization(&rate_limit, retries)) .expect("invalid template"), ); total_scan_progress_bar.enable_steady_tick(UPDATE_INTERVAL); let start_time = std::time::Instant::now(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { TcpSynMessage::Result((ip, port)) => { let ip_entry = 
results.entry(ip).or_default(); if ip_entry.insert(port) { open_ports_found += 1; } } TcpSynMessage::RxStats { packet_count, full_read_buffer_count, } => { total_packets_received += packet_count as u64; total_rx_buffer_full += full_read_buffer_count as u64; } TcpSynMessage::TxStats { syn_packet_count, packet_count, full_write_buffer_count, } => { total_syn_packets_sent += syn_packet_count as u64; total_packets_sent += packet_count as u64; total_tx_buffer_full += full_write_buffer_count as u64; } TcpSynMessage::Finished => { break; } } // Update the progress bar found_progress.set_message(format!("{}", open_ports_found)); buffer_spinner .set_message(format!("{}/{}", total_tx_buffer_full, total_rx_buffer_full)); total_packets_spinner .set_message(format!("{}/{}", total_packets_sent, total_packets_received)); total_scan_progress_bar.set_position(total_syn_packets_sent / (1 + retries as u64)); } let end_time = std::time::Instant::now(); // Finish the progress bar found_progress.finish(); buffer_spinner.finish(); total_packets_spinner.finish(); total_scan_progress_bar.finish_and_clear(); let duration = (end_time - start_time).as_secs(); if duration < 120 { println!("\n\nScan completed in {} seconds", duration); } else { let minutes = duration / 60; println!("\n\nScan completed in {} minutes", minutes); } results } } fn get_progress_stylization(rate_limit: &Option, retries: u8) -> &'static str { match (rate_limit, retries) { (None, 0) => { "[ETA: {eta_precise:.dim}] {wide_bar:.blue} {pos:.dim}/{len:.dim} ports (Elapsed: {elapsed_precise:.dim})" } _ => { "[ETA: {eta_precise:.dim}] {wide_bar:.red} {pos:.dim}/{len:.dim} ports (Elapsed: {elapsed_precise:.dim})" } } } ``` ### `clients/mrpf_cli/src/progress/tls.rs` **Absolute path:** `/root/mrpf/clients/mrpf_cli/src/progress/tls.rs` ```rust use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use mrpf_engine::ProgressHandler; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use mrpf_tls_scanner::TlsMessage; use std::collections::{HashMap, HashSet}; use std::net::Ipv4Addr; use std::sync::mpsc::Receiver; use std::time::Duration; const UPDATE_INTERVAL: Duration = Duration::from_secs(1); /// CLI ProgressHandler for the TCP SYN Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. A progress /// bar will be displayed in the CLI to show the progress of the scan. #[derive(Debug)] pub struct TlsProgressHandlerCLI { total_targets: u64, } impl TlsProgressHandlerCLI { /// Create a new CLI ProgressHandler for the TCP SYN Scanner. pub fn new(targets: &[(Vec, Ipv4AddressRanges, PortRanges)]) -> Self { let total_targets = targets .iter() .map(|(names, ips, ports)| names.len() as u64 * ips.len() as u64 * ports.len() as u64) .sum(); TlsProgressHandlerCLI { total_targets } } } impl ProgressHandler for TlsProgressHandlerCLI { type Message = TlsMessage; type Results = HashMap>; fn handler(&mut self, receiver: Receiver) -> Self::Results { let mut results = HashMap::>::with_capacity(16); // Variables holding statistics let rate_limit: Option = None; // not sure why/what? 
let retries = 0; let mut total_packets_received: u64 = 0; let mut total_rx_buffer_full: u64 = 0; let mut total_syn_packets_sent: u64 = 0; let mut total_packets_sent: u64 = 0; let mut total_tx_buffer_full: u64 = 0; let mut names_found = 0; // Prepare our CLI progress indicator let multi_pb = MultiProgress::new(); let found_progress = multi_pb.add(ProgressBar::new_spinner()); found_progress.set_message("0"); found_progress.set_style( ProgressStyle::default_spinner() .template("{spinner:.yellow} Names found: {msg:.green}") .expect("invalid template"), ); found_progress.enable_steady_tick(UPDATE_INTERVAL); let buffer_spinner = multi_pb.add(ProgressBar::new_spinner()); buffer_spinner.set_message("0/0"); buffer_spinner.set_style( ProgressStyle::default_spinner() .template("{spinner:.yellow} Tx/Rx Buffer full: {msg:.red}") .expect("invalid template"), ); let total_packets_spinner = multi_pb.add(ProgressBar::new_spinner()); total_packets_spinner.set_message("0/0"); total_packets_spinner.set_style( ProgressStyle::default_spinner() .template("{spinner:.yellow} Tx/Rx Packets: {msg:.blue}") .expect("invalid template"), ); total_packets_spinner.enable_steady_tick(UPDATE_INTERVAL); let total_scan_progress_bar = multi_pb.add(ProgressBar::new(self.total_targets)); total_scan_progress_bar.set_message("0"); total_scan_progress_bar.set_style( ProgressStyle::default_bar() .template(get_progress_stylization(&rate_limit, retries)) .expect("invalid template"), ); total_scan_progress_bar.enable_steady_tick(UPDATE_INTERVAL); let start_time = std::time::Instant::now(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { TlsMessage::Result((names, ip, port)) => { for name in names { let entry = results.entry(name.clone()).or_default(); entry.insert((ip, port)); if entry.len() == 1 { names_found += 1; } } } TlsMessage::RxStats { packet_count, full_read_buffer_count, } => { total_packets_received += packet_count as u64; total_rx_buffer_full += full_read_buffer_count as u64; } TlsMessage::TxStats { syn_packet_count, packet_count, full_write_buffer_count, } => { total_syn_packets_sent += syn_packet_count as u64; total_packets_sent += packet_count as u64; total_tx_buffer_full += full_write_buffer_count as u64; } TlsMessage::Finished => { break; } } // Update the progress bar found_progress.set_message(format!("{}", names_found)); buffer_spinner .set_message(format!("{}/{}", total_tx_buffer_full, total_rx_buffer_full)); total_packets_spinner .set_message(format!("{}/{}", total_packets_sent, total_packets_received)); total_scan_progress_bar.set_position(total_syn_packets_sent / (1 + retries as u64)); } let end_time = std::time::Instant::now(); // Finish the progress bar found_progress.finish(); buffer_spinner.finish(); total_packets_spinner.finish(); total_scan_progress_bar.finish_and_clear(); let duration = (end_time - start_time).as_secs(); if duration < 120 { println!("\n\nScan completed in {} seconds", duration); } else { let minutes = duration / 60; println!("\n\nScan completed in {} minutes", minutes); } results } } fn get_progress_stylization(rate_limit: &Option, retries: u8) -> &'static str { match (rate_limit, retries) { (None, 0) => { "[ETA: {eta_precise:.dim}] {wide_bar:.blue} {pos:.dim}/{len:.dim} ports (Elapsed: {elapsed_precise:.dim})" } _ => { "[ETA: {eta_precise:.dim}] {wide_bar:.red} {pos:.dim}/{len:.dim} ports (Elapsed: {elapsed_precise:.dim})" } } } ``` ### `infrastructure/api/README.md` 
**Absolute path:** `/root/mrpf/infrastructure/api/README.md` ```markdown # MRPF API The MRPF API holds all data related to Bug Bounty hunting. ## Spice Deployment framework Usage - Install Spice deployment framework package by running ```python3 -m pip install spicedf``` - To review a deployment run ```spice review ``` - To deploy deployment stages run ```spice deploy ...`````` ### `infrastructure/api/src/api/src/error.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/error.rs` ```rust use aws_lambda_events::apigw::ApiGatewayProxyResponse; use aws_lambda_events::encodings::Body; use lambda_runtime::Error as LambdaError; use lambda_runtime::tracing::{error, warn}; use mrpf_core::Error as DbError; use std::{error::Error as StdError, fmt}; // For the Result alias, don't do: // ```type Result = std::result::Result;``` // But do: // ```type Result = std::result::Result;``` // // Explanation [here](https://users.rust-lang.org/t/dont-do-type-result-t-std-result-t-myerror/49421) // // A further improvement is to add an empty tuple to the result T, this again saves a few characters. // It might not be super obvious to people working with my code, but this is my preference so acceptable // for this private codebase. /// Result type alias for mrpf_core library /// /// - The result type defaults to an empty tuple /// - The error type is an explicit enum with conversion methods `crate::Error`. /// /// Error types: /// * `Database` - Database errors /// * `Json` - JSON parsing errors /// * `Other` - Other errors pub type Result = std::result::Result; /// Central API error type. #[derive(Debug)] pub enum Error { // External system / dependency errors Database(DbError), Json(serde_json::Error), Lambda(Box), // Request / protocol level ExpectedBody, RouteNotFound(String), MethodNotAllowed(String), BadRequest(String), Serialization(String), // Infra / config MissingEnv(String), // Generic catch-all (keep last) Other(String), } impl fmt::Display for Error { /// Formats the error for display fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Error::Database(e) => write!(f, "Database error: {e}"), Error::Json(e) => write!(f, "JSON error: {e}"), Error::Lambda(e) => write!(f, "Lambda runtime error: {e}"), Error::ExpectedBody => write!(f, "Expected request body"), Error::RouteNotFound(p) => write!(f, "Route not found: {p}"), Error::MethodNotAllowed(m) => write!(f, "Method not allowed: {m}"), Error::BadRequest(msg) => write!(f, "Bad request: {msg}"), Error::Serialization(ctx) => write!(f, "Serialization error: {ctx}"), Error::MissingEnv(k) => write!(f, "Missing environment variable: {k}"), Error::Other(msg) => write!(f, "{msg}"), } } } impl StdError for Error { /// Conform to std::error::Error trait fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { Error::Database(e) => Some(e), Error::Json(e) => Some(e), Error::Lambda(_) => None, // already boxed generic error, not exposing source Error::ExpectedBody | Error::RouteNotFound(_) | Error::MethodNotAllowed(_) | Error::BadRequest(_) | Error::Serialization(_) | Error::MissingEnv(_) | Error::Other(_) => None, } } } // Focused conversions so `?` just works: impl From for Error { fn from(e: DbError) -> Self { Error::Database(e) } } impl From for Error { fn from(e: serde_json::Error) -> Self { Error::Json(e) } } impl From for Error { fn from(e: LambdaError) -> Self { Error::Lambda(Box::new(e)) } } impl From<&str> for Error { fn from(s: &str) -> Self { Error::Other(s.into()) } } impl From<()> for Error { fn from(_: ()) 
-> Self { Error::Other("unit error".into()) } } impl From for Error { fn from(s: String) -> Self { Error::Other(s) } } // Reversed conversions, so we can wrap other error types into our Error type impl From for DbError { fn from(e: Error) -> Self { match e { Error::Database(db_err) => db_err, other => DbError::Other(format!("Converted from mrpf_api::Error: {other}")), } } } impl Error { /// Convert into an ApiGatewayProxyResponse logging internal details first. /// Return the original error message (Display) inside a JSON body: {"error": }. pub fn to_api_gw_response(&self) -> ApiGatewayProxyResponse { match self { Error::BadRequest(_) | Error::ExpectedBody | Error::RouteNotFound(_) | Error::MethodNotAllowed(_) => { warn!(err = ?self, "Client error"); } _ => { error!(err = ?self, "Server error"); } } let status = match self { Error::Database(_) | Error::Lambda(_) | Error::MissingEnv(_) | Error::Other(_) => 500, Error::Json(_) | Error::Serialization(_) => 400, Error::ExpectedBody => 400, Error::RouteNotFound(_) => 404, Error::MethodNotAllowed(_) => 405, Error::BadRequest(_) => 400, }; let body_text = serde_json::json!({"error": self.to_string()}).to_string(); // TODO the clean thing would be to set the content-type header to application/json here. ApiGatewayProxyResponse { status_code: status.into(), body: Some(Body::Text(body_text)), ..Default::default() } } } ``` ### `infrastructure/api/src/api/src/handlers/apn_devices.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/apn_devices.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_multi_strings, get_opt_string, get_path_string}; use mrpf_core::DbClient; use mrpf_core::recon::{ApnDevice, ApnDeviceEnv, ApnDeviceKind}; /// Handle all /iam/apn_devices* routes pub(crate) async fn handle_apn_devices_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/apn_devices" => { let ids = get_multi_strings( &request.multi_value_query_string_parameters, "apn_device_id", ); let kind = get_opt_string(&request.query_string_parameters, "kind").and_then(|k| match k .as_str() { "iphone" => Some(ApnDeviceKind::iPhone), "ipad" => Some(ApnDeviceKind::iPad), "macbook" => Some(ApnDeviceKind::MacBook), _ => None, }); let env = get_opt_string(&request.query_string_parameters, "env").and_then(|e| { match e.as_str() { "development" => Some(ApnDeviceEnv::Development), "production" => Some(ApnDeviceEnv::Production), _ => None, } }); let devices = db.list_apn_devices(ids, kind, env).await?; let body = serde_json::to_string(&devices) .map_err(|e| Error::Serialization(format!("apn_devices: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/apn_devices/{apn_device_id}" => { let id = get_path_string(&request.path_parameters, "apn_device_id")?; let device = db.get_apn_device(&id).await?; let body = serde_json::to_string(&device) .map_err(|e| Error::Serialization(format!("apn_device: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/apn_devices" => { let devices: Vec = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("ApnDevices JSON: {e:?}")))?; db.put_apn_devices(devices).await?; 
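// Bulk upsert accepted; respond with 204 No Content since there is no body to return.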
Ok(create_response(204, None)) } "/apn_devices/{apn_device_id}" => { let id = get_path_string(&request.path_parameters, "apn_device_id")?; let device: ApnDevice = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("ApnDevice JSON: {e:?}")))?; if device.id != id { Err(Error::BadRequest("Mismatched apn_device_id".into())) } else { db.put_apn_device(&device).await?; Ok(create_response(204, None)) } } _ => Err(Error::RouteNotFound(path.to_string())), } } HttpMethod::DELETE => match path { "/apn_devices" => { let ids = get_multi_strings( &request.multi_value_query_string_parameters, "apn_device_id", ); if ids.is_empty() { return Err(Error::BadRequest( "Missing at least one apn_device_id query parameter".into(), )); } db.delete_apn_devices(ids).await?; Ok(create_response(204, None)) } "/apn_devices/{apn_device_id}" => { let id = get_path_string(&request.path_parameters, "apn_device_id")?; db.delete_apn_device(&id).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/domains.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/domains.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_runtime::tracing::debug; use lambda_helpers::{get_multi_strings, get_opt_bool, get_opt_string, get_path_string}; use mrpf_core::DbClient; use mrpf_core::recon::Domain; use std::collections::HashSet; /// Handle all /iam/domains* routes pub(crate) async fn handle_domains_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/domains/search" => { let pattern = get_opt_string(&request.query_string_parameters, "pattern") .ok_or_else(|| Error::BadRequest("pattern required".into()))?; let deep = get_opt_bool(&request.query_string_parameters, "deep") .map_err(Error::BadRequest)? .unwrap_or(false); debug!(pattern = %pattern, deep = deep, "Searching FQDNs"); let fqdns = db.search_fqdns(&pattern, deep).await?; let body = serde_json::to_string(&fqdns) .map_err(|e| Error::Serialization(format!("search: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/domains/{fqdn}/history" => { let fqdn = get_path_string(&request.path_parameters, "fqdn")?; let history = db.list_fqdn_history(&fqdn).await?; let body = history .into_iter() .map(|(datetime, cnames, ip_addresses)| { serde_json::json!({ "datetime": datetime, "cnames": cnames, "ip_addresses": ip_addresses, }) }) .collect::>(); let body = serde_json::to_string(&body) .map_err(|e| Error::Serialization(format!("history: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/domains" => { Err(Error::BadRequest( "Use /iam/targets/{target_id}/domains to list domains. 
I think for now getting all domains could put a bit too much pressure on the database.".into(), )) } _ => Err(Error::RouteNotFound(path.to_string())) }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/domains/{fqdn}" => { let fqdn = get_path_string(&request.path_parameters, "fqdn")?; let domain: Domain = serde_json::from_value(body_value.clone()) .map_err(|e| Error::BadRequest(format!("Domain JSON: {e:?}")))?; if domain.fqdn != fqdn { return Err(Error::BadRequest("Mismatched fqdn".into())); } db.put_domains(HashSet::from_iter(vec![domain])).await?; Ok(create_response(204, None)) } "/domains" => { let domains = serde_json::from_value(body_value.clone()) .map_err(|e| Error::BadRequest(format!("Domains JSON: {e:?}")))?; db.put_domains(domains).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())) } } HttpMethod::DELETE => match path { "/domains" => { let fqdns = get_multi_strings(&request.multi_value_query_string_parameters, "fqdn"); if fqdns.is_empty() { return Err(Error::BadRequest("Missing at least one fqdn query parameter".into())); } db.delete_domains(fqdns).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())) }, _ => Err(Error::MethodNotAllowed("".into())) } } ``` ### `infrastructure/api/src/api/src/handlers/ips.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/ips.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_opt_string, get_path_string}; use mrpf_core::DbClient; use std::collections::HashSet; /// Handle all /iam/ips* routes pub(crate) async fn handle_ips_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/ips" => { let target_id = get_opt_string(&request.query_string_parameters, "target_id"); let ips = db.list_ip_addresses(target_id).await?; Ok(create_response( 200, Some( &serde_json::to_string(&ips) .map_err(|e| Error::Serialization(format!("ips: {e:?}")))?, ), )) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => match path { "/ips" => { let body_value = parse_body_to_json(request.body.clone())?; let ips = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("invalid ips: {e}")))?; db.put_ip_addresses(ips).await?; Ok(create_response(204, None)) } "/ips/{ip}/ports" => { let ip = get_path_string(&request.path_parameters, "ip")?; let body_value = parse_body_to_json(request.body.clone())?; let tcp = body_value.get("tcp").and_then(|v| v.as_array()).map(|a| { a.iter() .filter_map(|x| { x.as_u64().and_then(|n| { if n <= u16::MAX as u64 { Some(n as u16) } else { None } }) }) .collect::>() }); let udp = body_value.get("udp").and_then(|v| v.as_array()).map(|a| { a.iter() .filter_map(|x| { x.as_u64().and_then(|n| { if n <= u16::MAX as u64 { Some(n as u16) } else { None } }) }) .collect::>() }); db.update_ip_ports(&ip, tcp, udp).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::POST => { let ip = get_path_string(&request.path_parameters, "ip")?; let body_value = parse_body_to_json(request.body.clone())?; let arr = body_value .get("ports") .and_then(|v| v.as_array()) .ok_or_else(|| Error::BadRequest("ports array required".into()))?; let mut ports = 
HashSet::with_capacity(arr.len()); for v in arr { let n = v .as_u64() .ok_or_else(|| Error::BadRequest("port must be number".into()))?; if n > u16::MAX as u64 { return Err(Error::BadRequest("port out of range".into())); } ports.insert(n as u16); } match path { "/ips/{ip}/ports/tcp" => { db.add_tcp_ports(&ip, ports).await?; Ok(create_response(204, None)) } "/ips/{ip}/ports/udp" => { db.add_udp_ports(&ip, ports).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), } } _ => Err(Error::MethodNotAllowed("unsupported method".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/jobs.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/jobs.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_multi_uuids, get_path_uuid}; use mrpf_core::DbClient; use mrpf_core::recon::Job; /// Handle all /jobs* routes pub(crate) async fn handle_jobs_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/jobs" => { let job_ids = get_multi_uuids(&request.multi_value_query_string_parameters, "job_id")?; let jobs = db.list_jobs(job_ids).await?; Ok(create_response( 200, Some( &serde_json::to_string(&jobs) .map_err(|e| Error::Serialization(format!("jobs: {e:?}")))?, ), )) } "/jobs/{job_id}" => { let job_id = get_path_uuid(&request.path_parameters, "job_id")?; let job = db.get_job(&job_id).await?; let body = serde_json::to_string(&job) .map_err(|e| Error::Serialization(format!("job: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/jobs/{job_id}" => { let job_id = get_path_uuid(&request.path_parameters, "job_id")?; let job: Job = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("TaskCollections JSON: {e:?}")))?; if job.id != job_id { return Err(Error::BadRequest( "TaskCollection ID in path does not match ID in body".into(), )); } let updated_job = db.put_job(job).await?; // TODO: Invoke the mrpf_task_scheduler lambda after each successful creation // The mrpf_task_scheduler will poll our database to find new jobs to schedule immediately. // This avoids waiting for schedules when immediate execution is desired. Ok(create_response( 200, Some(&serde_json::to_string(&updated_job).map_err(|e| { Error::Serialization(format!("job PUT response: {e:?}")) })?), )) } "/jobs" => { let jobs: Vec = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("TaskCollections JSON: {e:?}")))?; if jobs.is_empty() { return Err(Error::BadRequest( "At least one TaskCollection must be provided".into(), )); } let updated_jobs = db.put_jobs(jobs).await?; // TODO: Invoke the mrpf_task_scheduler lambda after each successful creation // The mrpf_task_scheduler will poll our database to find new jobs to schedule immediately. // This avoids waiting for schedules when immediate execution is desired. 
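// Echo the persisted jobs back to the caller as the JSON response body.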
Ok(create_response( 200, Some(&serde_json::to_string(&updated_jobs).map_err(|e| { Error::Serialization(format!("jobs PUT response: {e:?}")) })?), )) } _ => Err(Error::RouteNotFound(path.to_string())), } } HttpMethod::DELETE => { let job_ids = get_multi_uuids(&request.multi_value_query_string_parameters, "job_id")?; if job_ids.is_empty() { return Err(Error::BadRequest( "Missing at least one job_id query parameter".into(), )); } db.delete_jobs(job_ids).await?; Ok(create_response(204, None)) } _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/mod.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/mod.rs` ```rust mod apn_devices; mod domains; mod ips; mod jobs; mod notifications; mod targets; mod template_keys; mod word_categories; mod words; pub(crate) use apn_devices::handle_apn_devices_route; pub(crate) use domains::handle_domains_route; pub(crate) use ips::handle_ips_route; pub(crate) use jobs::handle_jobs_route; pub(crate) use notifications::handle_notifications_route; pub(crate) use targets::handle_targets_route; pub(crate) use template_keys::handle_template_keys_route; pub(crate) use word_categories::handle_word_categories_route; pub(crate) use words::handle_words_route; ``` ### `infrastructure/api/src/api/src/handlers/notifications.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/notifications.rs` ```rust use crate::error::{Error, Result}; use crate::utils::{ApnProvider, send_apn_notifications}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{ get_multi_strings, get_multi_uuids, get_opt_bool, get_opt_i64, get_opt_string, get_path_uuid, }; use mrpf_core::DbClient; use mrpf_core::recon::{Notification, NotificationCategory}; use uuid::Uuid; /// Handle all /iam/notifications* routes pub(crate) async fn handle_notifications_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, apn_provider: &ApnProvider, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/notifications" => { let ids = get_multi_uuids( &request.multi_value_query_string_parameters, "notification_id", )?; let page = get_opt_i64(&request.query_string_parameters, "page") .map_err(Error::BadRequest)?; let records_per_page = get_opt_i64(&request.query_string_parameters, "records_per_page") .map_err(Error::BadRequest)?; // let sort = get_opt_string(&request.query_string_parameters, "sort"); let category = get_opt_string(&request.query_string_parameters, "category") .and_then(|c| match c.as_str() { "generic" => Some(NotificationCategory::Generic), "task_collection_result" => { Some(NotificationCategory::TaskCollectionResult) } "alert" => Some(NotificationCategory::Alert), _ => None, }); let is_read = get_opt_bool(&request.query_string_parameters, "is_read") .map_err(Error::BadRequest)?; let target_id = get_opt_string(&request.query_string_parameters, "target_id"); let notifications = db .list_notifications(ids, category, is_read, target_id, page, records_per_page) .await?; let body = serde_json::to_string(¬ifications) .map_err(|e| Error::Serialization(format!("notifications: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/notifications/{notification_id}" => { let id = get_path_uuid(&request.path_parameters, "notification_id")?; let notification = db.get_notification(&id).await?; let body = serde_json::to_string(¬ification) .map_err(|e| 
Error::Serialization(format!("notification: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/notifications" => { // Create new notifications let notifications: Vec = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("Notifications JSON: {e:?}")))?; db.put_notifications(¬ifications).await?; // Send notifications to all registered APN devices send_apn_notifications(db, apn_provider, notifications).await?; Ok(create_response(204, None)) } "/notifications/{notification_id}/read" => { let id = get_path_uuid(&request.path_parameters, "notification_id")?; db.mark_notification_read(&id).await?; Ok(create_response(204, None)) } "/notifications/{notification_id}/unread" => { let id = get_path_uuid(&request.path_parameters, "notification_id")?; db.mark_notification_unread(&id).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), } } HttpMethod::DELETE => match path { "/notifications" => { let ids = get_multi_strings( &request.multi_value_query_string_parameters, "notification_id", ) .into_iter() .filter_map(|s| Uuid::parse_str(&s).ok()) .collect::>(); if ids.is_empty() { return Err(Error::BadRequest( "Missing at least one notification_id query parameter".into(), )); } db.delete_notifications(ids).await?; Ok(create_response(204, None)) } "/notifications/{notification_id}" => { let id = get_path_uuid(&request.path_parameters, "notification_id")?; db.delete_notification(&id).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/targets.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/targets.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_multi_strings, get_opt_bool, get_path_string}; use mrpf_core::DbClient; use mrpf_core::recon::{SoaRecord, Target}; use std::collections::HashSet; /// Handle all /targets* routes pub(crate) async fn handle_targets_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/targets" => { let target_ids = get_multi_strings(&request.multi_value_query_string_parameters, "target_id"); let targets = db.list_targets(target_ids).await?; Ok(create_response( 200, Some( &serde_json::to_string(&targets) .map_err(|e| Error::Serialization(format!("targets: {e:?}")))?, ), )) } "/targets/{target_id}" => { let target_id = get_path_string(&request.path_parameters, "target_id")?; let target = db.get_target(&target_id).await?; let body = serde_json::to_string(&target) .map_err(|e| Error::Serialization(format!("target: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/targets/{target_id}/domains" => { let target_id = get_path_string(&request.path_parameters, "target_id")?; let is_active = get_opt_bool(&request.query_string_parameters, "is_active") .map_err(Error::BadRequest)?; let is_wildcard = get_opt_bool(&request.query_string_parameters, "is_wildcard") .map_err(Error::BadRequest)?; let is_dangling = get_opt_bool(&request.query_string_parameters, "is_dangling") .map_err(Error::BadRequest)?; let domains = db 
.list_domains(target_id, is_active, is_wildcard, is_dangling) .await?; let body = serde_json::to_string(&domains) .map_err(|e| Error::Serialization(format!("domains: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/targets/{target_id}/ips" => { let target_id = get_path_string(&request.path_parameters, "target_id")?; let ips = db.list_ip_addresses(Some(target_id)).await?; let body = serde_json::to_string(&ips) .map_err(|e| Error::Serialization(format!("ips: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/targets/{target_id}/soa" => { let target_id = get_path_string(&request.path_parameters, "target_id")?; let soa_records = db.list_soa_records(&target_id).await?; let body = serde_json::to_string(&soa_records) .map_err(|e| Error::Serialization(format!("soa_records: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::DELETE => { let target_ids = get_multi_strings(&request.multi_value_query_string_parameters, "target_id"); if target_ids.is_empty() { return Err(Error::BadRequest( "Missing at least one target_id query parameter".into(), )); } db.delete_targets(target_ids).await?; Ok(create_response(204, None)) } HttpMethod::PUT => match path { "/targets" => { let body_value = parse_body_to_json(request.body.clone())?; let targets: Vec = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("Targets JSON: {e:?}")))?; db.put_targets(targets).await?; Ok(create_response(204, None)) } "/targets/{target_id}" => { let target_id = request.path_parameters.get("target_id").cloned().unwrap(); let body_value = parse_body_to_json(request.body.clone())?; let target: Target = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("Target JSON: {e:?}")))?; if target.id != target_id { Err(Error::BadRequest("Mismatched target_id".into())) } else { db.put_target(&target).await?; Ok(create_response(204, None)) } } "/targets/{target_id}/favorite" => { let target_id = get_path_string(&request.path_parameters, "target_id")?; db.mark_target_favorite(&target_id).await?; Ok(create_response(204, None)) } "/targets/{target_id}/unfavorite" => { let target_id = get_path_string(&request.path_parameters, "target_id")?; db.mark_target_unfavorite(&target_id).await?; Ok(create_response(204, None)) } "/targets/{target_id}/soa" => { let target_id = request .path_parameters .get("target_id") .cloned() .ok_or_else(|| Error::BadRequest("Missing path parameter: target_id".into()))?; let body_value = parse_body_to_json(request.body.clone())?; let soa_records_vec: HashSet = serde_json::from_value(body_value) .map_err(|e| Error::BadRequest(format!("SOA Records JSON: {e:?}")))?; let soa_records: HashSet = soa_records_vec.into_iter().collect(); db.set_target_soa_records(&target_id, soa_records).await?; Ok(create_response(204, None)) } "/targets/{target_id}/words" => { let target_id = request .path_parameters .get("target_id") .cloned() .ok_or_else(|| Error::BadRequest("Missing path parameter: target_id".into()))?; let word_names = get_multi_strings(&request.multi_value_query_string_parameters, "word_name"); db.delete_words_by_name_and_target(word_names, Some(target_id)) .await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/template_keys.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/template_keys.rs` ```rust use crate::error::{Error, Result}; use 
crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_multi_strings, get_multi_uuids}; use lambda_runtime::tracing::debug; use mrpf_core::DbClient; use mrpf_core::recon::TemplateKey; /// Handle all /template_keys* routes pub(crate) async fn handle_template_keys_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/template_keys" => { let template_key_ids = get_multi_uuids( &request.multi_value_query_string_parameters, "template_key_id", )?; let template_key_names: Vec = get_multi_strings( &request.multi_value_query_string_parameters, "template_key_name", ); debug!(template_key_id = ?template_key_ids, template_key_names = ?template_key_names, "Listing template keys"); let wordlists = db .list_template_keys(template_key_ids, template_key_names) .await?; let body = serde_json::to_string(&wordlists) .map_err(|e| Error::Serialization(format!("search: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/template_keys" => { let template_keys: Vec = serde_json::from_value(body_value.clone()) .map_err(|e| Error::BadRequest(format!("TemplateKey JSON: {e:?}")))?; if template_keys.is_empty() { return Err(Error::BadRequest("No template_keys provided".into())); } db.put_template_keys(template_keys).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), } } HttpMethod::DELETE => match path { "/template_keys" => { let template_key_ids = get_multi_uuids( &request.multi_value_query_string_parameters, "template_key_id", )?; db.delete_template_keys(template_key_ids).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/word_categories.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/word_categories.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_multi_strings, get_multi_uuids}; use lambda_runtime::tracing::debug; use mrpf_core::DbClient; use mrpf_core::recon::WordCategory; /// Handle all /word_categories* routes pub(crate) async fn handle_word_categories_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/word_categories" => { let category_ids = get_multi_uuids(&request.multi_value_query_string_parameters, "category_id")?; let category_names = get_multi_strings( &request.multi_value_query_string_parameters, "category_name", ); debug!(category_id = ?category_ids, category_names = ?category_names, "Listing categories"); let wordlists = db.list_categories(category_ids, category_names).await?; let body = serde_json::to_string(&wordlists) .map_err(|e| Error::Serialization(format!("search: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/word_categories" => { let word_categories: Vec = 
serde_json::from_value(body_value.clone()) .map_err(|e| Error::BadRequest(format!("WordCategory JSON: {e:?}")))?; if word_categories.is_empty() { return Err(Error::BadRequest("No word_categories provided".into())); } db.put_word_categories(word_categories).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), } } HttpMethod::DELETE => match path { "/word_categories" => { let category_ids = get_multi_uuids(&request.multi_value_query_string_parameters, "category_id")?; db.delete_word_categories(category_ids).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/handlers/words.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/handlers/words.rs` ```rust use crate::error::{Error, Result}; use crate::{create_response, parse_body_to_json}; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use aws_lambda_events::http::Method as HttpMethod; use lambda_helpers::{get_multi_strings, get_multi_uuids, get_opt_string, get_path_uuid}; use lambda_runtime::tracing::debug; use mrpf_core::DbClient; use mrpf_core::recon::Word; /// Handle all /words* routes pub(crate) async fn handle_words_route( path: &str, request: &ApiGatewayProxyRequest, db: &DbClient, ) -> Result { match request.http_method { HttpMethod::GET => match path { "/words" => { let target_id: Option = get_opt_string(&request.query_string_parameters, "target_id"); let included_category_names = get_multi_strings( &request.multi_value_query_string_parameters, "included_category_name", ); let excluded_category_names: Vec = get_multi_strings( &request.multi_value_query_string_parameters, "excluded_category_name", ); debug!(target_id = ?target_id, included_categories = ?included_category_names, excluded_categories = ?excluded_category_names, "Generating wordlist"); let wordlists = db .generate_wordlist(included_category_names, excluded_category_names, target_id) .await?; let body = serde_json::to_string(&wordlists) .map_err(|e| Error::Serialization(format!("search: {e:?}")))?; Ok(create_response(200, Some(&body))) } "/words/{word_id}" => { let word_id = get_path_uuid(&request.path_parameters, "word_id")?; debug!(word_id = %word_id, "Searching word IDs"); let word = db.get_word(word_id).await?; let body = serde_json::to_string(&word) .map_err(|e| Error::Serialization(format!("search: {e:?}")))?; Ok(create_response(200, Some(&body))) } _ => Err(Error::RouteNotFound(path.to_string())), }, HttpMethod::PUT => { let body_value = parse_body_to_json(request.body.clone())?; match path { "/words" => { let words: Vec = serde_json::from_value(body_value.clone()) .map_err(|e| Error::BadRequest(format!("Word JSON: {e:?}")))?; if words.is_empty() { return Err(Error::BadRequest("No words provided".into())); } db.put_words(words).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), } } HttpMethod::DELETE => match path { "/words" => { let word_ids = get_multi_uuids(&request.multi_value_query_string_parameters, "word_id")?; if word_ids.is_empty() { return Err(Error::BadRequest( "Missing at least one word_id query parameter".into(), )); } db.delete_words(word_ids).await?; Ok(create_response(204, None)) } _ => Err(Error::RouteNotFound(path.to_string())), }, _ => Err(Error::MethodNotAllowed("".into())), } } ``` ### `infrastructure/api/src/api/src/main.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/main.rs` 
```rust mod error; mod handlers; mod utils; use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse}; use error::{Error, Result}; use handlers::{ handle_apn_devices_route, handle_domains_route, handle_ips_route, handle_jobs_route, handle_notifications_route, handle_targets_route, handle_template_keys_route, handle_word_categories_route, handle_words_route, }; use lambda_helpers::{get_db_connection_details, init_tracing}; use lambda_runtime::Error as LambdaError; use lambda_runtime::tracing::{debug, error, info}; use lambda_runtime::{LambdaEvent, run, service_fn}; use mrpf_core::DbClient; use utils::ApnProvider; /// Main handler function for the Lambda. async fn function_handler( event: LambdaEvent, db: &DbClient, apn_provider: &ApnProvider, ) -> Result { debug!(event = ?event, "Received API Gateway Proxy Request"); match route_handler(event.payload, db, apn_provider).await { Ok(response) => Ok(response), Err(err) => Ok(err.to_api_gw_response()), } } /// Route incoming requests to the appropriate handler async fn route_handler( request: ApiGatewayProxyRequest, db: &DbClient, apn_provider: &ApnProvider, ) -> Result { // NOTE: Path has the real route used by the client, but we're interested // in the resource pattern, eg. /iam/targets/{target_id}/domains // We strip the /iam or /web prefix so we can more easily match paths. let raw_path = request .resource .as_deref() .ok_or_else(|| Error::BadRequest("Could not parse path from incoming request".into()))?; // Strip leading "/iam" or "/web" prefixes so routing can match against the resource pattern. let path = raw_path .strip_prefix("/iam") .or_else(|| raw_path.strip_prefix("/web")) .unwrap_or(raw_path); info!("Routing request for path: {}", path); if path.starts_with("/targets") { handle_targets_route(path, &request, db).await } else if path.starts_with("/domains") { handle_domains_route(path, &request, db).await } else if path.starts_with("/ips") { handle_ips_route(path, &request, db).await } else if path.starts_with("/words") { handle_words_route(path, &request, db).await } else if path.starts_with("/word_categories") { handle_word_categories_route(path, &request, db).await } else if path.starts_with("/template_keys") { handle_template_keys_route(path, &request, db).await } else if path.starts_with("/jobs") { handle_jobs_route(path, &request, db).await } else if path.starts_with("/notifications") { handle_notifications_route(path, &request, db, apn_provider).await } else if path.starts_with("/apn_devices") { handle_apn_devices_route(path, &request, db).await } else { Err(Error::RouteNotFound(path.to_string())) } } /// Convert incoming body to a serde_json::Value regardless of encoding fn parse_body_to_json(body: Option) -> Result { match body { Some(s) => serde_json::from_str::(&s).map_err(|e| { error!(error = ?e, "Failed to parse request body as JSON"); Error::from(e) }), None => Err(Error::ExpectedBody), } } /// Create HTTP response with given status code and body fn create_response(status_code: u16, body: Option<&str>) -> ApiGatewayProxyResponse { let body_opt = body.map(|b| aws_lambda_events::encodings::Body::Text(b.to_string())); ApiGatewayProxyResponse { status_code: status_code.into(), body: body_opt, ..Default::default() } } /// Main entrypoint for the Lambda. /// /// This code will be run only during INIT phase meaning variables instantiated here /// will potentially be re-used across invocations. 
/// /// Make sure to provide more detailed error results here as the default CloudWatch /// Logs only print out Runtime.Error without any context. The cleanest way is probably /// to use map_err()? on every fallible operation and log the error before returning it. /// /// More information [here](https://docs.aws.amazon.com/lambda/latest/dg/rust-handler.html#rust-shared-state) #[tokio::main] async fn main() -> Result<(), LambdaError> { init_tracing(); // Load AWS config and create Secrets Manager client let config = aws_config::load_from_env().await; // Initialize RDS PostgreSQL client and SQS client and invoke our lambda debug!("Composing database connection details"); let (host, port, database_name, username, password) = get_db_connection_details(&config).await?; debug!("Initializing database client"); let db = DbClient::new( &host, port, &database_name, &username, &password, "api_lambda", "recon,util,public", ) .await .map_err(|e| { error!("Failed initialize database client: {:?}", e); e })?; debug!("Initializing APN Provider"); // Note that we generate a token here on startup so lambdas will likely have a cached token // across invocations. However, we might need to generate a new token during invocation, which // requires &mut ApnProvider. This means the re-generated token will not be cached across invocations. // Since this will likely be a rare occurrence, we accept the performance hit. let mut apn_provider = ApnProvider::from_aws_secrets(&config).await?; let _ = apn_provider.generate_token()?; debug!("Starting Lambda runtime"); run(service_fn(|event: LambdaEvent| { let db = &db; let apn_provider = &apn_provider; async move { function_handler(event, db, apn_provider).await } })) .await?; Ok(()) } ``` ### `infrastructure/api/src/api/src/utils/apn_notifications.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/utils/apn_notifications.rs` ```rust use crate::error::{Error, Result}; use crate::utils::ApnProvider; use lambda_runtime::tracing::{debug, error, info, warn}; use mrpf_core::DbClient; use mrpf_core::recon::{ApnDeviceEnv, Notification}; /// Send one or more notifications to all registered APN devices. /// /// NOTE: In a future refactor the `ApnTokenProvider` should be created during Lambda /// cold start and reused across invocations (e.g. static Mutex). For now it's per-call /// which still caches within the call if multiple token requests happen. pub async fn send_apn_notifications( db_client: &DbClient, apn_provider: &ApnProvider, notifications: Vec, ) -> Result<()> { if notifications.is_empty() { debug!("send_apn_notifications called with 0 notifications"); return Ok(()); } // Retrieve JWT token for APN, uses cached token if still valid let jwt = match apn_provider.get_cached_token() { Some(token) => { debug!("Using cached APN JWT token"); token } None => { warn!( "Generating new APN JWT token during lambda invocation. Will not be cached across other warm invocations." ); let mut apn_provider = apn_provider.clone(); apn_provider.generate_token()? 
} }; // Load devices (all) let devices = db_client .list_apn_devices(Vec::new(), None, None) .await .map_err(Error::from)?; if devices.is_empty() { debug!("No APN devices registered; skipping send"); return Ok(()); } // Prepare reqwest client let client = match reqwest::Client::builder().http2_prior_knowledge().build() { Ok(c) => c, Err(e) => { error!(error=?e, "Failed to build reqwest client for APN"); return Err(Error::Other("Reqwest client build failure".into())); } }; let mut unregistered: Vec = Vec::new(); let auth_header_value = format!("bearer {}", jwt); for notification in notifications.iter() { let collapse_id = notification .target_id .as_deref() .filter(|s| !s.is_empty()) .unwrap_or("generic"); let payload = notification.to_apn_json(); let body = match serde_json::to_vec(&payload) { Ok(b) => b, Err(e) => { warn!(error=?e, id=%notification.id, "Failed to serialize notification JSON"); continue; } }; for device in devices.iter() { let base = match device.env { ApnDeviceEnv::Development => apn_provider.dev_endpoint(), ApnDeviceEnv::Production => apn_provider.prod_endpoint(), }; let url = format!("{}{}", base, device.id); let request = client .post(&url) .header("authorization", &auth_header_value) .header("apns-expiration", "0") .header("apns-priority", "10") .header("apns-topic", apn_provider.bundle_id()) .header("apns-push-type", "alert") .header("apns-collapse-id", collapse_id) .header("content-type", "application/json") .body(body.clone()); match request.send().await { Ok(resp) => { let status = resp.status(); if status.is_success() { info!(device_id=%device.id, notification_id=%notification.id, status=%status.as_u16(), "APN notification sent"); } else if status.as_u16() == 410 { warn!(device_id=%device.id, "APN device returned 410 (unregistered)"); unregistered.push(device.id.clone()); } else { let text = resp.text().await.unwrap_or_default(); warn!(device_id=%device.id, status=%status.as_u16(), body=%text, "APN notification failed"); } } Err(e) => { warn!(error=?e, device_id=%device.id, "HTTP error sending APN notification"); } } } } if !unregistered.is_empty() { info!( count = unregistered.len(), "Removing unregistered APN devices" ); unregistered.sort_unstable(); unregistered.dedup(); if let Err(e) = db_client.delete_apn_devices(unregistered.clone()).await { warn!(error=?e, "Failed to delete some unregistered APN devices"); } } Ok(()) } ``` ### `infrastructure/api/src/api/src/utils/apn_token_provider.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/utils/apn_token_provider.rs` ```rust use crate::error::{Error, Result}; use chrono::Utc; use jsonwebtoken::{Algorithm, EncodingKey, Header}; use lambda_runtime::Error as LambdaError; use lambda_runtime::tracing::{debug, error}; use std::env; /// Cached APN JWT provider. /// /// Apple allows reusing the same JWT for up to 60 minutes. We keep a conservative /// lifetime (default 50 minutes) to avoid clock skew problems. Token generation /// is pure (no network roundtrip) so this mostly saves CPU work. #[derive(Clone)] pub struct ApnProvider { team_id: String, bundle_id: String, key_id: String, encoding_key: EncodingKey, cached: Option, lifetime_secs: u64, } #[derive(Clone)] struct CachedToken { token: String, iat: i64, } impl ApnProvider { /// Create a new provider from raw PEM key. 
pub fn new( team_id: String, bundle_id: String, key_id: String, pem: String, lifetime_secs: Option, ) -> Result { let encoding_key = EncodingKey::from_ec_pem(pem.as_bytes()).map_err(|e| { error!(error=?e, "Failed to parse APN ECDSA private key"); Error::Other("Invalid APN auth key".into()) })?; Ok(Self { team_id, bundle_id, key_id, encoding_key, cached: None, lifetime_secs: lifetime_secs.unwrap_or(50 * 60), // 50 minutes }) } pub fn dev_endpoint(&self) -> &str { "https://api.sandbox.push.apple.com/3/device/" } pub fn prod_endpoint(&self) -> &str { "https://api.push.apple.com/3/device/" } pub fn bundle_id(&self) -> &str { &self.bundle_id } /// Get the cached token string, if still valid. /// /// This will return None if the token is expired or never generated. This /// function is useful in Lambda invocations where we might have a shared /// state across invocations warm start. In this case we need to use this /// function and if None, we need to explicitly regenerate the token for /// each warm invocation after this. /// /// This should be a rare occurrence as the token is valid for 50 minutes /// and most warm Lambda invocations won't survive that long. pub fn get_cached_token(&self) -> Option { let now = Utc::now().timestamp(); if let Some(ref c) = self.cached && (now - c.iat) < self.lifetime_secs as i64 { return Some(c.token.clone()); } None } /// Get a JWT suitable for APNs authorization header. Uses cached value if still valid. pub fn generate_token(&mut self) -> Result { let now = Utc::now().timestamp(); if let Some(ref c) = self.cached && (now - c.iat) < self.lifetime_secs as i64 { return Ok(c.token.clone()); } #[derive(serde::Serialize)] struct Claims<'a> { iss: &'a str, iat: i64, } let iat = now; let claims = Claims { iss: &self.team_id, iat, }; let mut header = Header::new(Algorithm::ES256); header.kid = Some(self.key_id.clone()); let token_str = jsonwebtoken::encode(&header, &claims, &self.encoding_key).map_err(|e| { error!(error=?e, "Failed to encode APN JWT"); Error::Other("Failed to encode APN JWT".into()) })?; self.cached = Some(CachedToken { token: token_str, iat, }); // Safe unwrap: just assigned. Ok(self.cached.as_ref().unwrap().token.clone()) } /// Create a new provider by retrieving secrets from AWS Secrets Manager. 
pub async fn from_aws_secrets(config: &aws_config::SdkConfig) -> Result<Self> { let (team_id, bundle_id, key_id, pem) = get_apn_secret(config).await.map_err(|e| { error!(error=?e, "Failed to retrieve APN secret from AWS Secrets Manager"); Error::Other("Failed to retrieve APN secret".into()) })?; Self::new(team_id, bundle_id, key_id, pem, None) } } /// Retrieve APN secret from AWS Secrets Manager /// /// Expected JSON structure: /// ```json /// { /// "team_id": "DQ6H5PTWJ9", /// "bundle_id": "nl.mortimer.MyPensionFund", /// "key_id": ".....", /// "auth_token_signing_key": "-----BEGIN PRIVATE KEY-----\n....\n-----END PRIVATE KEY-----" /// } /// ``` pub async fn get_apn_secret( config: &aws_config::SdkConfig, ) -> Result<(String, String, String, String), LambdaError> { let secret_name = env::var("APN_SECRET")?; let client = aws_sdk_secretsmanager::Client::new(config); debug!("Created Secrets Manager client"); let secret_value = client .get_secret_value() .secret_id(secret_name) .send() .await?; debug!("Successfully retrieved secret string"); let secret_string = secret_value .secret_string() .ok_or_else(|| LambdaError::from("Secret is not a string"))?; let secrets: serde_json::Value = serde_json::from_str(secret_string)?; let team_id = secrets["team_id"] .as_str() .ok_or_else(|| LambdaError::from("Missing team_id"))? .to_owned(); let bundle_id = secrets["bundle_id"] .as_str() .ok_or_else(|| LambdaError::from("Missing bundle_id"))? .to_owned(); let key_id = secrets["key_id"] .as_str() .ok_or_else(|| LambdaError::from("Missing key_id"))? .to_owned(); let pem = secrets["auth_token_signing_key"] .as_str() .ok_or_else(|| LambdaError::from("Missing auth_token_signing_key (PEM)"))? .to_owned(); Ok((team_id, bundle_id, key_id, pem)) } ``` ### `infrastructure/api/src/api/src/utils/mod.rs` **Absolute path:** `/root/mrpf/infrastructure/api/src/api/src/utils/mod.rs` ```rust mod apn_notifications; mod apn_token_provider; pub use apn_notifications::send_apn_notifications; pub use apn_token_provider::ApnProvider; ``` ### `infrastructure/bare_metal_worker/OPTIMIZELINUX.md` **Absolute path:** `/root/mrpf/infrastructure/bare_metal_worker/OPTIMIZELINUX.md` ```markdown # When I'm ready, I should investigate what kind of OS-level optimizations I can do ChatGPT: Great question, Mathijs. Since your Rust binary will use **raw sockets** (with `cap_net_raw+ep`) and run on an **AL2023 minimal AMI** on EC2, you can push quite a bit more performance out of the kernel path with a mix of NIC/driver settings, network‑stack sysctls, CPU/IRQ affinity, and (optionally) busy‑polling or XDP/AF_XDP. Below is a practical, “pick‑what‑you‑need” checklist plus ready‑to‑run snippets. --- ## 0) First principles (EC2 + ENA) - Make sure the instance is using **ENA enhanced networking** and the **ENA driver** is up to date; ENA exposes multiple Rx/Tx queue pairs and MSI‑X vectors and supports XDP/AF_XDP on Nitro instances. [1](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/enhanced-networking-ena.html)[2](https://docs.kernel.org/networking/device_drivers/ethernet/amazon/ena.html) - Queue count is bounded by instance family/size. Many non‑network‑accelerated families (e.g., **c5**) cap at **8 queues** per ENI; network‑accelerated families (e.g., **c5n**) can expose up to **32**—relevant if you plan to scale with RPS/XDP.
[3](https://github.com/amzn/amzn-drivers/issues/294) --- ## 1) Fastest user–kernel paths: consider AF_PACKET (TPACKETv3) or AF_XDP If you can slightly adapt the app interface: - **AF_PACKET with PACKET_MMAP / TPACKET_V3** gives a zero‑copy ring per socket in userspace and typically **~15–20% less CPU** and **~20% higher capture rate** vs plain raw sockets. Use `TPACKET_V3` and fanout if you parallelize. [4](https://www.kernel.org/doc/html/latest/networking/packet_mmap.html)[5](https://stackoverflow.com/questions/21200009/capturing-performance-with-pcap-vs-raw-socket) - **AF_XDP** (XDP → userspace rings) can bypass much of the stack and provide **zero‑copy** when the NIC/driver supports it (ENA does). It binds an XDP program to a queue and delivers frames directly into your UMEM ring. Best‑case latency/pps among kernel‑provided paths. [6](https://www.kernel.org/doc/html/latest/networking/af_xdp.html)[7](https://deepwiki.com/amzn/amzn-drivers/2.3-ena-advanced-features) - DPDK and the XDP project both provide AF_XDP examples and drivers; check kernel ≥4.18/5.4 for zero‑copy and need_wakeup support. [8](https://doc.dpdk.org/guides/nics/af_xdp.html) If you must stay with `AF_INET/SOCK_RAW`, keep reading—you can still get significant wins. --- ## 2) NIC/driver tuning (ENA) for raw sockets 1) **Right‑size Rx/Tx rings** to reduce drops under bursty load: ```bash # see current ethtool -g eth0 # enlarge (pick values your instance & ENA support) sudo ethtool -G eth0 rx 4096 tx 4096 ``` Increasing ring size reduces `rx_queue_*_drops` in `ethtool -S` when the kernel can’t drain fast enough. [9](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/8/html/monitoring_and_managing_system_status_and_performance/tuning-the-network-performance_monitoring-and-managing-system-status-and-performance) 2) **Interrupt coalescing**: coalesce a little to raise throughput (pps) without adding too much latency: ```bash # example: 128 µs Rx coalesce (adjust to your latency budget) sudo ethtool -C eth0 rx-usecs 128 ``` Coalescing reduces interrupt storms and can lift throughput; AWS uses similar guidance in low‑latency ingest scenarios. [10](https://docs.aws.amazon.com/ground-station/latest/gs-agent-ug/ec2-instance-performance-tuning.html) 3) **Offloads (GRO/LRO/TSO/GSO)** and raw sockets: If you need **packet‑exact** view (e.g., custom L3 parsing), **disable GRO/LRO** so the NIC/stack doesn’t merge frames; these offloads are great for TCP but can distort raw‑packet workloads. [11](https://docs.kernel.org/networking/segmentation-offloads.html)[12](https://my.f5.com/manage/s/article/K39505011) ```bash sudo ethtool -K eth0 gro off lro off gso off tso off ``` (You can re‑enable selectively later; persist via your ifcfg/systemd‑networkd profile.) [13](https://serverfault.com/questions/463111/how-to-persist-ethtool-settings-through-reboot) --- ## 3) Scale packet processing across CPUs: RSS + RPS/RFS + XPS - ENA provides **multi‑queue RSS** (hardware steering) that spreads flows across queue/IRQ pairs; confirm/adjust RxFH indirection if needed. [14](https://www.kernel.org/doc/Documentation/networking/scaling.txt) - Complement with **RPS/RFS** (software steering) to spread softirq work and improve **cache locality** towards the CPU running your thread(s). 
[15](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rps)[16](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rfs) - Optionally pin **XPS** (Tx steering) to keep egress on the same CPUs as ingress. [14](https://www.kernel.org/doc/Documentation/networking/scaling.txt) **Quick recipe** (assume `eth0` and N queues): ```bash # (a) Map each rx-queue to a CPU mask for RPS for q in /sys/class/net/eth0/queues/rx-*; do # example: all CPUs allowed; tailor hex mask to your CPU set echo ffffffff > $q/rps_cpus # (b) Enable per-queue flow count for RFS (divide total across queues) echo 4096 > $q/rps_flow_cnt done # (c) Set total sock flow table entries (power of two) echo 65536 | sudo tee /proc/sys/net/core/rps_sock_flow_entries ``` RPS spreads softirq processing; RFS steers flows to the CPU where the **consuming thread** runs for better cache hit rates. [15](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rps)[16](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rfs) --- ## 4) Make the kernel spend more time in your socket: NAPI busy‑polling Linux supports **busy‑polling** on sockets and in the core via: - `SO_BUSY_POLL` / `SO_PREFER_BUSY_POLL` socket options, and - sysctls `net.core.busy_read` and `net.core.busy_poll` (µs). [17](https://docs.kernel.org/networking/napi.html)[18](https://sysctl-explorer.net/net/core/busy_poll/) This reduces wakeups and tail latency by polling NAPI queues instead of waiting for interrupts. Trade‑off: increased CPU power draw. **Sysctl example** (start modestly and measure): ```bash sudo sysctl -w net.core.busy_poll=50 sudo sysctl -w net.core.busy_read=50 ``` If you later switch to **io_uring**, you can also register NAPI busy‑poll preferences in the ring (maps to `SO_PREFER_BUSY_POLL`). [19](https://www.man7.org/linux/man-pages/man3/io_uring_register_napi.3.html) In Rust, you can enable per‑socket busy‑poll: ```rust use std::os::fd::AsRawFd; use nix::sys::socket::setsockopt; use nix::sys::socket::sockopt::*; let fd = socket.as_raw_fd(); let usec: i32 = 50; setsockopt(fd, BusyPoll, &usec)?; // SO_BUSY_POLL let prefer: i32 = 1; setsockopt(fd, PreferBusyPoll, &prefer)?; // SO_PREFER_BUSY_POLL (if available) ``` (Use `libc` if your nix version lacks these constants.) --- ## 5) Prevent drops before the stack: backlog & instrumentation - Increase the per‑CPU **ingress backlog** to avoid drops when softirqs can’t keep up: ```bash sudo sysctl -w net.core.netdev_max_backlog=32768 ``` If `/proc/net/softnet_stat` column 2 increases, you’re overflowing the backlog—raise it further. [20](https://steemit.com/linux/@blake.heo/linux-net-core-netdevmaxbacklog-tuning) - Watch `ethtool -S eth0 | grep -E 'rx.*drop|rx_queue'` to see NIC‑level drops; if they grow, enlarge Rx rings (above). [9](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/8/html/monitoring_and_managing_system_status_and_performance/tuning-the-network-performance_monitoring-and-managing-system-status-and-performance) --- ## 6) CPU and IRQ affinity: keep hot paths on the right cores - **Pin NIC IRQs** (per‑queue MSI‑X interrupts) and softirqs to specific CPUs; align your process/thread CPU affinity to those same CPUs to minimize cross‑CPU bouncing. 
[21](https://www.kernel.org/doc/html/latest/networking/scaling.html)[22](https://documentation.ubuntu.com/real-time/rt-conf/how-to/tune-irq-affinity/) - If you already rely on **irqbalance**, you can still hand‑tune or even temporarily disable it for tighter control; AWS showed wins from manual pinning on some instance families. [23](https://aws.amazon.com/blogs/compute/optimizing-network-intensive-workloads-on-amazon-ec2-a1-instances/) **Quick check & pinning:** ```bash grep -i ena /proc/interrupts # find IRQs per queue (ens5/eth0) # Example: pin IRQ 76 to CPU2 (bitmask: 1<<2 = 0x4) echo 4 | sudo tee /proc/irq/76/smp_affinity ``` (You can script this for all Rx queues; there are helper scripts to distribute IRQs evenly.) [24](https://github.com/suominen/network-irq-affinity) --- ## 7) Give your binary priority: real‑time scheduling, CPU pinning, memory locking You can bias the scheduler and keep your threads resident: - Run the service with **SCHED_RR** or **SCHED_FIFO** and a high RT priority via **systemd** (no code changes): Add a service drop‑in: ```ini # /etc/systemd/system/yourapp.service.d/rt.conf [Service] CPUSchedulingPolicy=rr CPUSchedulingPriority=90 IOSchedulingClass=realtime IOSchedulingPriority=0 LimitRTPRIO=99 LimitMEMLOCK=infinity CPUAffinity=2 3 # pin to the CPUs you also used for IRQ/RPS ``` Then `systemctl daemon-reload && systemctl restart yourapp`. This uses the kernel RT policies (prio 1–99). [25](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux_for_real_time/8/html/optimizing_rhel_8_for_real_time_for_low_latency_operation/assembly_viewing-scheduling-priorities-of-running-threads_optimizing-rhel8-for-real-time-for-low-latency-operation)[26](https://www.man7.org/linux/man-pages/man7/sched.7.html) - If you set RT from within the app, make sure the service/user has `LimitRTPRIO` (or PAM limits) so `sched_setscheduler()` doesn’t fail with EPERM. [27](https://stackoverflow.com/questions/72593494/change-linux-thread-priority-to-real-time-sched-fifo) - Consider `mlockall(MCL_CURRENT|MCL_FUTURE)` to avoid page‑fault hiccups along the hot path (requires proper `LimitMEMLOCK` as above). --- ## 8) Socket/stack buffers and MTU - For high PPS with small frames, increase **rmem/wmem** maxima and defaults; for large payloads, consider **jumbo MTU** end‑to‑end to reduce per‑packet overhead (only if your path supports it). [9](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/8/html/monitoring_and_managing_system_status_and_performance/tuning-the-network-performance_monitoring-and-managing-system-status-and-performance) ```bash sudo sysctl -w net.core.rmem_max=33554432 net.core.wmem_max=33554432 sudo sysctl -w net.core.rmem_default=262144 net.core.wmem_default=262144 # MTU only if every hop supports it; verify before changing. # ip link set dev eth0 mtu 9001 ``` --- ## 9) Optional: step up to XDP/AF_XDP when you need more If you eventually outgrow raw sockets: - Attach a minimal XDP program on the ENA interface and **redirect** to AF_XDP sockets bound to specific queues; your Rust process reads frames from UMEM rings (zero copy). You’ll trade some complexity for very high pps and lower jitter. [6](https://www.kernel.org/doc/html/latest/networking/af_xdp.html)[28](https://github.com/xdp-project/xdp-tutorial/blob/master/advanced03-AF_XDP/README.org) - Note ENA queue limits by instance type (see §0) and that XDP may consume queue resources; plan cores/queues accordingly. 
[3](https://github.com/amzn/amzn-drivers/issues/294) --- ## 10) A turnkey “baseline” script you can bake into the AMI Here’s a conservative baseline (safe for most ENA‑backed instances). It applies ring/feature tweaks, enables RPS/RFS across all CPUs, raises backlogs, and sets modest busy‑polling. Adjust CPU masks/values for your box. ```bash #!/usr/bin/env bash set -euo pipefail IF="${1:-eth0}" # 1) NIC rings & moderate coalescing ethtool -G "$IF" rx 4096 tx 4096 || true ethtool -C "$IF" rx-usecs 128 || true # 2) Disable coalescing offloads that distort packet granularity for raw sockets ethtool -K "$IF" gro off lro off gso off tso off || true # 3) RPS/RFS: use all CPUs (compute full mask) NCPU=$(nproc) # Build hex mask with NCPU bits set MASK="" bits=$NCPU while (( bits > 0 )); do take=$(( bits > 32 ? 32 : bits )) MASK="$(printf '%x' $(( (1< "$q"/rps_cpus || true echo 4096 > "$q"/rps_flow_cnt || true done echo 65536 > /proc/sys/net/core/rps_sock_flow_entries || true # 4) Backlogs & buffers sysctl -w net.core.netdev_max_backlog=32768 \ net.core.rmem_default=262144 net.core.wmem_default=262144 \ net.core.rmem_max=33554432 net.core.wmem_max=33554432 >/dev/null # 5) Busy-poll (very modest) sysctl -w net.core.busy_poll=50 net.core.busy_read=50 >/dev/null ``` Why these choices: larger rings/backlog reduce drops under burst; RPS/RFS spreads CPU load and improves locality; offload disables give you packet‑true semantics; busy‑polling trims tail latency for poll/epoll loops. [9](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/8/html/monitoring_and_managing_system_status_and_performance/tuning-the-network-performance_monitoring-and-managing-system-status-and-performance)[15](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rps)[16](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/6/html/performance_tuning_guide/network-rfs)[11](https://docs.kernel.org/networking/segmentation-offloads.html)[20](https://steemit.com/linux/@blake.heo/linux-net-core-netdevmaxbacklog-tuning)[17](https://docs.kernel.org/networking/napi.html) --- ## 11) Verifying improvements - **NIC stats**: `ethtool -S eth0 | egrep 'rx.*(drop|errors)|tx.*(drop|errors)|rx_queue'` (drops trending to zero is good). [9](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux/8/html/monitoring_and_managing_system_status_and_performance/tuning-the-network-performance_monitoring-and-managing-system-status-and-performance) - **Backlog overflows**: `/proc/net/softnet_stat` column 2 per CPU should stay at 0 during load. [20](https://steemit.com/linux/@blake.heo/linux-net-core-netdevmaxbacklog-tuning) - **Interrupt balance**: `watch -n1 cat /proc/interrupts` while generating traffic; IRQ lines should grow on the CPUs you pinned. [22](https://documentation.ubuntu.com/real-time/rt-conf/how-to/tune-irq-affinity/) - **Throughput/latency**: Use your app’s own counters; for synthetic tests, `pktgen`, `moongen`, or `trafgen` work well. 
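If you'd rather watch those backlog counters from your own tooling than from a shell one‑liner, here is a minimal Rust sketch (standard library only; it assumes the usual `/proc/net/softnet_stat` layout, where the second hex column is the backlog drop count):

```rust
// Minimal sketch: print per-CPU backlog drops from /proc/net/softnet_stat.
// The second hex column counts packets dropped because netdev_max_backlog
// overflowed; if it keeps growing under load, raise the backlog (see section 5).
use std::fs;

fn main() -> std::io::Result<()> {
    let stat = fs::read_to_string("/proc/net/softnet_stat")?;
    for (cpu, line) in stat.lines().enumerate() {
        let dropped = line
            .split_whitespace()
            .nth(1) // column 2 = dropped count, hex encoded
            .and_then(|field| u64::from_str_radix(field, 16).ok())
            .unwrap_or(0);
        println!("cpu{cpu}: backlog drops = {dropped}");
    }
    Ok(())
}
```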
---

## Implementation help: systemd unit for RT + capabilities

Your AMI can ship a service definition that grants capabilities and RT priority without requiring root:

```ini
# /etc/systemd/system/yourapp.service
[Unit]
Description=Your raw-socket Rust app

[Service]
ExecStart=/opt/yourapp/yourapp-binary
# Capabilities: CAP_NET_RAW already needed; add CAP_SYS_NICE to allow RT if set inside the app
AmbientCapabilities=CAP_NET_RAW CAP_SYS_NICE
CapabilityBoundingSet=CAP_NET_RAW CAP_SYS_NICE
NoNewPrivileges=true
# Prefer systemd to set RT for you (see §7)
CPUSchedulingPolicy=rr
CPUSchedulingPriority=90
LimitRTPRIO=99
LimitMEMLOCK=infinity
CPUAffinity=2 3
Restart=on-failure

[Install]
WantedBy=multi-user.target
```

This uses Linux RT scheduling and affinity via systemd’s `CPUScheduling*` directives (range 1–99 for `fifo/rr`). [25](https://docs.redhat.com/en/documentation/red_hat_enterprise_linux_for_real_time/8/html/optimizing_rhel_8_for_real_time_for_low_latency_operation/assembly_viewing-scheduling-priorities-of-running-threads_optimizing-rhel8-for-real-time-for-low-latency-operation)[26](https://www.man7.org/linux/man-pages/man7/sched.7.html)

---

### Want me to tailor this to your target instance?

If you tell me the **instance family/size** you plan to use and whether you’re open to **AF_PACKET/AF_XDP**, I can:

- compute the exact **CPU masks** and per‑queue RPS/RFS numbers,
- generate an **IRQ pinning** script for ENA’s queue/IRQ layout,
- and give you a minimal **Rust snippet** for busy‑poll + `SO_RCVBUF`/`SO_SNDBUF` tuning specific to your packet sizes.

Also, do you prefer **lowest latency** or **max PPS**? That will swing coalescing, busy‑poll numbers, and the RT settings.
```

### `infrastructure/bare_metal_worker/README.md`

**Absolute path:** `/root/mrpf/infrastructure/bare_metal_worker/README.md`

```markdown
# Bare Metal Worker

This builds a new lightweight EC2 AMI using AWS Image Builder that is optimized for running my MRPF scanners. The scanners require raw sockets, which are not supported on Fargate or Lambda.

The AMI will be published to S3 and exported using CloudFormation outputs so that it can be used in other stacks.

The task_manager stack will contain an SQS queue dedicated to managing the bare metal workers. When the task manager schedules a task for a bare metal worker, it will send a message to this queue. Either a Lambda function or CloudWatch Metric Math will start up the required number of bare metal workers based on the number of messages in the queue. Refer to the following link for [EC2 auto scaling using Metric Math]()

## How a bare metal worker works

1. When started, the worker will poll for a new, single task from the SQS queue.
2. When a task is received, the worker will process the message and execute the task.
3. When the task is complete, it will send the regular Task Manager updates back to the task manager SQS queue.
4. The worker will then try to process another task from the queue. If no task is available within a 5-minute period, the worker will shut itself down to save costs (see the sketch at the end of this README).

## Codebase

To ensure we can run bare metal workers outside of AWS, e.g. a MacBook, a VPS, or a Raspberry Pi dropped in a datacenter, the worker logic is maintained in the `mrpf_bare_metal_worker` crate. The worker will run as a separate binary that can be built and deployed to the bare metal worker AMI. This repository is only responsible for building the most optimized AMI image and publishing it to S3.
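Below is a minimal sketch of the poll/execute/shutdown loop described above. It is not the actual `mrpf_bare_metal_worker` implementation: `Task`, `poll_task`, and `execute_task` are hypothetical stand-ins for the real SQS and task-handling code; only the 5-minute idle shutdown mirrors the behaviour documented in this README.

```rust
use std::time::{Duration, Instant};

/// Hypothetical task type; the real worker deserializes tasks from SQS messages.
struct Task;

/// Hypothetical stand-in for "poll a single task from the SQS queue".
fn poll_task() -> Option<Task> {
    None // placeholder: the real implementation would long-poll SQS here
}

/// Hypothetical stand-in for executing a task and sending Task Manager updates.
fn execute_task(_task: Task) {}

fn main() {
    // Shut down after 5 minutes without receiving a task, as described above.
    let idle_limit = Duration::from_secs(5 * 60);
    let mut last_task_at = Instant::now();

    loop {
        match poll_task() {
            Some(task) => {
                execute_task(task);
                last_task_at = Instant::now();
            }
            // No work for 5 minutes: stop the instance to save costs.
            None if last_task_at.elapsed() >= idle_limit => break,
            // Otherwise wait briefly before polling again.
            None => std::thread::sleep(Duration::from_secs(5)),
        }
    }
}
```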
``` ### `infrastructure/db_and_network/README.md` **Absolute path:** `/root/mrpf/infrastructure/db_and_network/README.md` ```markdown # MRPF Database and Network Infrastructure The MRPF Database and Network Infrastructure is responsible for managing the database and network resources for the MRPF application. The code in this infrastructure folder is being kept to a minimum. This will allow us to move to other infrastructure solutions without much hassle. deploy using ```sh cd infrastructure/db_and_network spice deploy Test ``` ## Spice Deployment framework Usage - Install Spice deployment framework package by running ```python3 -m pip install spicedf``` - To review a deployment run ```spice review ``` - To deploy deployment stages run ```spice deploy ...``` ``` ### `infrastructure/task_manager/README.md` **Absolute path:** `/root/mrpf/infrastructure/task_manager/README.md` ```markdown # MRPF Task Manager The MRPF Task Manager is responsible for running tasks related to Bug Bounty hunting. The code in this infrastructure folder is being kept to a minimum. This will allow us to move to other infrastructure solutions without much hassle. deploy using ```sh cd infrastructure/task_manager spice deploy Test --rust-support --arm64 ``` ## Spice Deployment framework Usage - Install Spice deployment framework package by running ```python3 -m pip install spicedf``` - To review a deployment run ```spice review ``` - To deploy deployment stages run ```spice deploy ...``` ``` ### `infrastructure/task_manager/src/manager/src/main.rs` **Absolute path:** `/root/mrpf/infrastructure/task_manager/src/manager/src/main.rs` ```rust use aws_lambda_events::event::sqs::SqsEvent; use lambda_helpers::{get_db_connection_details, init_tracing}; use lambda_runtime::tracing::{debug, error, info}; use lambda_runtime::{Error, LambdaEvent, run, service_fn}; use mrpf_core::DbClient; use mrpf_core::tasks::QueueClient; use mrpf_core::tasks::manager::Manager; use mrpf_core::tasks::models::TaskManagerQueueMessage; // TODO: Error Handling they recommend to implement From Diagnostics for your Error type // https://github.com/awslabs/aws-lambda-rust-runtime?tab=readme-ov-file#implement-your-own-diagnostic // https://docs.aws.amazon.com/lambda/latest/dg/rust-handler.html#rust-error // impl From for Diagnostic { // fn from(error: ErrorResponse) -> Diagnostic { // Diagnostic { // error_type: "MyErrorType".into(), // error_message: error.0.to_string(), // } // } // } /// Main handler function for the Lambda /// /// # Arguments /// /// * `event` - LambdaEvent - The event that triggered the Lambda /// * `db` - DbClient - The database client to interact with the Job Manager database /// * `queue` - QueueClient - The queue client to interact with SQS queues // #[tracing::instrument( // skip(sqs_event), // fields( // req_id = %sqs_event.context.request_id, // function_name = %sqs_event.context.function_name, // function_version = %sqs_event.context.function_version // ) // )] async fn function_handler( event: LambdaEvent, manager: &Manager, ) -> Result<(), Error> { let messages = match serde_json::from_value(event.payload) { Ok(sqs_event) => { debug!("Converting SQS record(s) to TaskManagerQueueMessages",); parse_records(LambdaEvent { payload: sqs_event, context: event.context, })? } Err(_) => { // Likely dealing with a scheduled event. 
info!("Event is not an SQS event, likely dealing with scheduled trigger"); Vec::new() } }; info!(message_count = %messages.len(),"Invoking manager"); match manager.run(messages).await { Ok(_) => { info!("Manager completed successfully"); } Err(e) => { error!("Manager encountered an error: {:?}", e); return Err(e.into()); } }; Ok(()) } /// Try to deserialize a LambdaEvent into a vector of TaskManagerQueueMessage. /// /// # Arguments /// * `event` - LambdaEvent - The event that triggered the Lambda /// /// # Returns /// * `Result, serde_json::Error>` - A result containing a vector of TaskManagerQueueMessage or a serde_json::Error pub fn parse_records( event: LambdaEvent, ) -> Result, serde_json::Error> { event .payload .records .iter() .map(|record| { if let Some(body) = &record.body { debug!(body = %body, "Deserializing SQS record"); serde_json::from_str::(body.as_str()) } else { debug!("Record body is None when parsing SQS record"); Err(serde_json::Error::io(std::io::Error::new( std::io::ErrorKind::InvalidData, "Record body is None", ))) } }) .collect() } /// Main entrypoint for the Lambda. /// /// This code will be run only during INIT phase meaning variables instantiated here /// will potentially be re-used across invocations. /// /// Make sure to provide more detailed error results here as the default CloudWatch /// Logs only print out Runtime.Error without any context. The cleanest way is probably /// to use map_err()? on every fallible operation and log the error before returning it. /// /// More information [here](https://docs.aws.amazon.com/lambda/latest/dg/rust-handler.html#rust-shared-state) #[tokio::main] async fn main() -> Result<(), Error> { init_tracing(); // Load AWS config and create Secrets Manager client let config = aws_config::load_from_env().await; // Initialize RDS PostgreSQL client and SQS client and invoke our lambda debug!("Composing database connection details"); let (host, port, database_name, username, password) = get_db_connection_details(&config).await.map_err(|e| { error!("Failed to compose connection details: {:?}", e); e })?; debug!("Initializing database client"); let db_client = DbClient::new( &host, port, &database_name, &username, &password, "manager_lambda", "task_manager,recon,util,public", ) .await .map_err(|e| { error!("Failed initialize database client: {:?}", e); e })?; debug!("Initializing SQS queue client"); let queue_client = QueueClient::new(&config); let manager = Manager::new(db_client, queue_client); debug!("Starting Lambda runtime"); // run(service_fn(|sqs_event: LambdaEvent| { run(service_fn(|event: LambdaEvent| { let manager = &manager; async move { function_handler(event, manager).await } })) .await?; Ok(()) } ``` ### `infrastructure/task_manager/src/worker/src/main.rs` **Absolute path:** `/root/mrpf/infrastructure/task_manager/src/worker/src/main.rs` ```rust use aws_lambda_events::event::sqs::SqsEvent; use lambda_helpers::{get_db_connection_details, init_tracing}; use lambda_runtime::tracing::{debug, info, warn}; use lambda_runtime::{Error, LambdaEvent, run, service_fn}; use mrpf_core::DbClient; use mrpf_core::tasks::models::{Task, TaskContainerId, WorkerQueueMessage}; use mrpf_core::tasks::{ExecutionContext, QueueClient}; use mrpf_task_worker::handle_tasks; /// Main handler function for the Lambda /// /// # Arguments /// /// * `event` - LambdaEvent - The event that triggered the Lambda /// * `db` - DbClient - The database client to interact with the Job Manager database /// * `queue` - QueueClient - The queue client to interact with 
SQS queues async fn function_handler( sqs_event: LambdaEvent, context: &ExecutionContext, ) -> Result<(), Error> { if sqs_event.payload.records.len() > 1 { warn!( sqs_records = %sqs_event.payload.records.len(), "Received multiple records, but expecting only one. This might lead to increased latency, timeouts or memory consumption. Consider reducing the batch size in the SQS event source mapping configuration." ); } info!( "Converting SQS event to {} WorkerQueueMessage(s)", sqs_event.payload.records.len() ); let tasks = parse_records(sqs_event)?; let messages = handle_tasks(tasks, context).await?; if !messages.is_empty() { info!( message_count = %messages.len(), "Sending TaskManagerQueueMessage(s) to Task Manager", ); context.queue_client()?.send_mgr_messages(&messages).await?; } else { info!("No messages to send to Task Manager"); } Ok(()) } /// Try to deserialize a LambdaEvent into a vector of Tasks. /// /// Since we only have a single WorkerQueuMessage enum type StartTask, we unwrap the task here. /// If we need more in the future, we will return WorkerQueueMessage directly. /// /// # Arguments /// * `event` - LambdaEvent - The event that triggered the Lambda /// /// # Returns /// * `Result, serde_json::Error>` - A result containing a vector of Tasks or a serde_json::Error pub fn parse_records( event: LambdaEvent, ) -> Result, serde_json::Error> { let messages: Result, serde_json::Error> = event .payload .records .iter() .map(|record| { serde_json::from_str::( record.body.as_ref().expect("Failed to get body").as_str(), ) }) .collect(); let messages = messages?; let tasks: Vec<(TaskContainerId, Task)> = messages .into_iter() .map(|msg| match msg { WorkerQueueMessage::StartTask { task, root_container_id, } => (root_container_id, task), }) .collect(); Ok(tasks) } /// Main entrypoint for the Lambda. /// /// This code will be run only during INIT phase meaning variables instantiated here /// will potentially be re-used across invocations. /// /// More information [here](https://docs.aws.amazon.com/lambda/latest/dg/rust-handler.html#rust-shared-state) #[tokio::main] async fn main() -> Result<(), Error> { init_tracing(); let config = aws_config::load_from_env().await; let queue_client = QueueClient::new(&config); // // Initialize RDS PostgreSQL client and SQS client and invoke our lambda. match get_db_connection_details(&config).await { Ok((host, port, database_name, username, password)) => { // We're not properly leveraging the WorkerRequirements here as we have the same lambda code for handling // internal and external workers. Perhaps we can do something more with ExecutionContextBuilder eventually // but maybe thats not needed in Lambda. It could be more useful perhaps in more permanent compute environments? debug!( "Successfully retrieved database connection details, assuming we're dealing with an Internal Worker" ); let db = DbClient::new( &host, port, &database_name, &username, &password, "worker_lambda", "recon,task_manager,util,public", // TODO: perhaps we need to move this to the client itself? 
) .await?; let context = ExecutionContext::builder() .db(db) .queue(queue_client) .build(); run(service_fn(|sqs_event: LambdaEvent| { let context_ref = &context; async move { function_handler(sqs_event, context_ref).await } })) .await?; } Err(_) => { debug!( "Could not get database connection details, assuming we're dealing with an External Worker" ); let context = ExecutionContext::builder().queue(queue_client).build(); run(service_fn(|sqs_event: LambdaEvent| { let context_ref = &context; async move { function_handler(sqs_event, context_ref).await } })) .await?; } } Ok(()) } ``` ### `promptios.md` **Absolute path:** `/root/mrpf/promptios.md` ```markdown # Goal Add a reusable, environment-driven overlay that shows **both** in-progress and error states. Errors must **only** disappear when the user taps an explicit **Dismiss** button (no auto-dismiss). Use latest iOS26/macOS26 SwiftUI & concurrency features, released in 2025. ## Scope & Constraints:** - Do not duplicate UI; the overlay must be attached **once** near the root via a `ViewModifier` inside a `ZStack`. - Reuse existing loading state from `MPFAPI`: **inspect `MPFAPI` and all its extensions** to find every `*isLoading` boolean and treat the overlay as busy if **any** is `true`. - Refactor API calls in `MPFAPI` extensions to **throw** errors (no `fatalError`, `preconditionFailure`, `try!`, force unwraps). Propagate thrown errors to the overlay. - Keep public API changes minimal; prefer adding throwing variants and migrating call sites. ## Acceptance criteria - Overlay shows a spinner whenever **any** `MPFAPI.*isLoading*` is `true`. - When an error is thrown, the overlay switches to an error card with message and a **Dismiss** button; it **does not** auto-dismiss. - API calls no longer crash; they throw and are surfaced in the overlay. - Overlay added once at the root; no duplicated error/progress UI in leaf views. - Interaction is blocked during progress (configurable), not during error. ## Style Swift 6 concurrency SwiftUI (iOS26/macOS26) New iOS/mac26 2025 Liquid glass effects where applicable `@MainActor` correctness ``` ### `scanners/mrpf_dns_resolver/DNS_RECON_TRICKS.md` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/DNS_RECON_TRICKS.md` ```markdown # Tricks using DNS lookups for recon Read this for discovering S3 buckets using DNS enumeration: - ## So many different DNS responses. What are things that might give out interesting info? DNAME for instance? HTTPS? Any others? ## Lookups using PTR style domain names, we can find owners of IP ranges When doing reverse DNS lookups (PTR records) on IP addresses, sometimes the returned domain names can give clues about the owner of the IP range. eg. when looking up the Start of Authority (SOA) record for a PTR domain name like my transip vps, I get: dig 226.153.144.136.in-addr.arpa SOA Authority section: 153.144.136.in-addr.arpa. 1800 IN SOA ns0.transip.net. hostmaster.transip.nl. 2025101200 14400 1800 604800 86400 Seems like it indeed works BUT how does this handle it when we're dealing with smaller subnets like /25? According to RFC 2317 it should be done using CNAME records: either individual CNAME records for each IP address in the subnet, or a CNAME record with a range like this 130-131.128-255.153.144.136.in-addr.arpa. ``` ### `scanners/mrpf_dns_resolver/README.md` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/README.md` ```markdown # MRPF DNS An ultra fast DNS client built on to of the mrpf_engine. Look at masscan for tips. 
We're not looking at HTTPS DNS scanning. That will be in mrpf_https_dns package and will use cloudflare's DNS over HTTPS API. ## Features - Support for HTTPS/Type65 DNS queries to uncover extra hidden endpoints and ports for applications. Not widely used but if we can find one it might be worth it. Check RFC Use dig google.com TYPE65 # SVCB dig google.com HTTPS # HTTPS RR. Also check Copilit question: What are all the different DNS types? Search all old and new RFC's. sort them by probability of being useful for bug bounty recon Answer: Absolutely—here’s a **field‑tested, recon‑oriented view of DNS RR types**, pulled from the IANA registry and the relevant RFCs, then **ranked by how often they pay off during bug bounty reconnaissance**. At the end I include pointers to the **complete/authoritative type list** and a quick way to query lesser‑known types safely (RFC 3597). --- ## How I ranked them (so you can adapt it) I used three criteria: 1) **Asset discovery value** (find new hostnames, vendors, ports, or pathways), 2) **Misconfig/leak potential** (info that hints at takeovers, weak email posture, internal names, test endpoints), 3) **Breadth of deployment** (how often you’ll actually see it in the wild). Authoritative references: **IANA DNS RR TYPE registry** and the RFCs that define each type.[1](https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml) --- ## Top targets for bug‑bounty recon (ranked) > Each line includes the defining RFC(s), why it helps, and one or two practical digs you can paste. 1) **TXT** — *RFC 1035* (container); **SPF** (*RFC 7208*), **DKIM** (*RFC 6376*), **DMARC** (*RFC 7489*, being revised) *Why:* Treasure chest for verification strings and email auth policy—often reveals **SaaS vendors**, **shadow senders**, and **internal hostnames** via `include:` chains; DMARC `rua=` endpoints can expose ops addresses.[2](https://www.rfc-editor.org/rfc/rfc1035)[3](https://www.rfc-editor.org/rfc/rfc7208)[4](https://www.rfc-editor.org/rfc/rfc6376)[5](https://www.rfc-editor.org/rfc/rfc7489) ``` dig +short TXT example.com dig +short TXT _dmarc.example.com ``` *Tip:* SPF RR type 99 is deprecated—policies live in **TXT** only.[3](https://www.rfc-editor.org/rfc/rfc7208)[6](https://blog.dnsimple.com/2025/07/discontinuing-spf-record-type/) 2) **CNAME (5)** — *RFC 1035* *Why:* Points at **back‑end providers** (AWS, Azure, GCP, Vercel, Fastly…), which can indicate **subdomain takeover** candidates if the target name is unprovisioned. Also exposes organizational vendor map.[2](https://www.rfc-editor.org/rfc/rfc1035) ``` dig +noall +answer www.example.com CNAME ``` 3) **MX (15)** — *RFC 1035* *Why:* Leaks **mail provider** (Google/Microsoft/Proofpoint/Postfix on-prem), sometimes **per‑region hostnames**. Cross‑check against SPF/DKIM/DMARC for alignment gaps and third‑party senders.[2](https://www.rfc-editor.org/rfc/rfc1035) ``` dig +short MX example.com ``` 4) **NS (2) & SOA (6)** — *RFC 1035* *Why:* Exposes **DNS hoster** and, via SOA MNAME/RNAME, sometimes **internal naming** or email aliases. Changing NS during migrations can leave dangling zones around.[2](https://www.rfc-editor.org/rfc/rfc1035) ``` dig NS example.com dig SOA example.com +noall +answer ``` 5) **SVCB (64) / HTTPS (65)** — *RFC 9460* *Why:* Modern goldmine. Can reveal **alternate endpoints**, **ports**, **ALPN** (h2/h3), and sometimes new hostnames; **HTTPS(65)** applies to web origins; **SVCB(64)** generalizes to other services. 
Frequently missed by recon scripts.[7](https://standards.globalspec.com/std/14647706/rfc-9460) ``` dig -t TYPE65 www.example.com +noall +answer dig -t TYPE64 _imap._tcp.example.com +noall +answer ``` 6) **SRV (33)** — *RFC 2782* *Why:* Structured service discovery for SIP/XMPP/LDAP/VoIP/etc.; yields **target hostnames and ports**; the “target” **must not be a CNAME** per the RFC—misconfigs are common.[8](https://www.rfc-editor.org/rfc/rfc2782)[9](https://serverfault.com/questions/765743/published-srv-records-pointing-to-cname-alias-in-violation-of-rfc-2782) ``` dig _sip._tcp.example.com SRV +noall +answer ``` 7) **A (1) / AAAA (28)** — *RFC 1035*, *RFC 3596* *Why:* Baseline mapping, but still useful to **pivot via reverse PTR**, see anycast footprints, and compare **v4/v6 parity** (IPv6 sometimes exposes different edges).[2](https://www.rfc-editor.org/rfc/rfc1035)[10](https://www.rfc-editor.org/rfc/rfc3596) ``` dig +short A app.example.com dig +short AAAA app.example.com ``` 8) **PTR** — *RFC 1035*, operational guidance *RFC 1912* *Why:* Reverse DNS often leaks **hosting provider**, **legacy names**, or **internal conventions**; good for mapping IP assets to tenants. (Forward‑confirmed reverse DNS is an operational best‑practice; mismatches can be probative.)[2](https://www.rfc-editor.org/rfc/rfc1035)[11](https://datatracker.ietf.org/doc/html/rfc1912) ``` dig -x 203.0.113.10 +short ``` 9) **TLSA (DANE)** — *RFC 6698* (+ usage docs) *Why:* When present (mail/IMAP/SMTP/HTTPS), it identifies **TLS cert/key bindings** and **ports** via `_port._proto.name` labels; can reveal **non‑standard TLS endpoints**. (Requires DNSSEC for the security model, but you can still recon without validating.)[12](https://www.rfc-editor.org/rfc/rfc6698)[13](https://archive.org/details/rfc7673) ``` dig _443._tcp.login.example.com TLSA +noall +answer ``` 10) **CAA (257)** — *RFC 8659* (obsoletes 6844) *Why:* Lists **which CAs may issue** certs and optional **iodef** reporting URIs (great for security contact discovery). Look for gaps across subdomains (wildcard vs. leaf).[14](https://www.rfc-editor.org/rfc/rfc8659.html) ``` dig CAA example.com +noall +answer ``` 11) **NAPTR (35)** — *RFC 3403* *Why:* Regex‑based rewrite to URIs (ENUM, SIP); can chain into SRV and uncover **telephony/SIP real estate** and **migration breadcrumbs**. Less common but high‑signal when present.[15](https://datatracker.ietf.org/doc/html/rfc3403) ``` dig NAPTR example.com +noall +answer ``` 12) **DNAME (39)** — *RFC 6672* *Why:* **Subtree** redirection (like a CNAME for everything below a label). Handy to discover **renames/mergers** or staged cutovers; resolvers synthesize CNAMEs alongside.[16](https://www.rfc-editor.org/rfc/rfc6672) ``` dig DNAME legacy.example.com +noall +answer ``` 13) **URI (256)** — *RFC 7553* *Why:* Maps to direct URIs; not widely deployed for web, but some stacks (e.g., Kerberos) use it. When present, it can disclose **canonical service URLs** with paths.[17](https://datatracker.ietf.org/doc/html/rfc7553) ``` dig URI _ftp._tcp.example.com +noall +answer ``` 14) **SSHFP** — *RFC 4255* (+ algorithm updates) *Why:* Indicates **hosts offering SSH**, gives **key algorithms**; with DNSSEC it’s for trust, but for recon it reveals **fleet naming** and SSH surface. 
Adoption is still low.[18](https://datatracker.ietf.org/doc/html/rfc4255) ``` dig SSHFP host.example.com +noall +answer ``` 15) **LOC (29), HINFO (13), RP (17)** — *RFC 1876*, *RFC 1035*, *RFC 1183* *Why:* Rare now, but if present they can leak **geo**, **host platform**, or a **responsible-person mailbox**—all pivot points. (Treat as serendipity, not expectation.)[2](https://www.rfc-editor.org/rfc/rfc1035) --- ## Medium‑yield, situational - **DNSSEC set (DNSKEY, DS, RRSIG, NSEC/NSEC3):** Not directly “asset” records, but **NSEC** (not NSEC3) can allow **zone walking** to enumerate labels; **NSEC3** reduces this, though operational choices (iterations/salt) matter. Check if a zone is walkable.[19](https://www.rfc-editor.org/rfc/rfc4034)[20](https://www.rfc-editor.org/rfc/rfc5155) ``` # Try to walk (works only on some NSEC‑signed zones) dig nonexistent.example.com A +dnssec +noall +authority ``` - **CDS (59)/CDNSKEY (60):** Used for automating DS in parents; not often recon‑relevant but signal DNSSEC maturity.[1](https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml) - **HIP (55), IPSECKEY (45):** May expose VPN/host identity info or gateways; uncommon, but interesting when present.[18](https://datatracker.ietf.org/doc/html/rfc4255)[21](https://www.rfc-editor.org/rfc/rfc4025.html) --- ## Low‑yield (for recon), but good to recognize - **A/AAAA glue, DHCID (49), EUI48/64 (108/109), APL (42):** Niche/ops signal only.[1](https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml) - **DLV (32769):** Historic DNSSEC mechanism; seeing it suggests very old configs.[22](https://en.wikipedia.org/wiki/List_of_DNS_record_types) --- ## Protocol “meta” you’ll see in `dig` (useful context) - **OPT (EDNS(0) pseudo‑RR, type 41)** — enables larger UDP answers, DO bit for DNSSEC, and carries options (like ECS). `dig` shows this as the **“OPT PSEUDOSECTION”**. Good to know when inspecting traces.[23](https://www.rfc-editor.org/rfc/rfc6891)[24](https://serverfault.com/questions/1018425/what-does-opt-pseudosection-mean-in-dig-response) - **EDNS Client Subnet (ECS, RFC 7871)** — some resolvers send partial client subnet to authors; affects **geo‑steering** and can explain why answers change. Recon note: test from multiple vantage points.[25](https://www.rfc-editor.org/rfc/rfc7871) --- ## Quick commands & patterns that routinely surface new attack surface - **Check apex and `www` for HTTPS(65) and ALPNs** (HTTP/3 rarely mirrored everywhere): ``` dig -t TYPE65 example.com +noall +answer dig -t TYPE65 www.example.com +noall +answer ``` *SVCB param key meanings here (alpn, port, ipv4hint, ech)*: see IANA SVCB key registry.[26](https://www.ietf.org/archive/id/draft-ietf-dnsop-nsec3-guidance-10.html) - **Enumerate common SRV services** (SIP, XMPP, LDAP, Kerberos, Autodiscover): ``` for s in _sip._tcp _sip._udp _xmpp-client._tcp _xmpp-server._tcp _ldap._tcp _kerberos._udp _autodiscover._tcp; do dig $s.example.com SRV +short done ``` (SRV format and semantics in RFC 2782.)[8](https://www.rfc-editor.org/rfc/rfc2782) - **Email posture map:** ``` dig +short MX example.com dig +short TXT example.com # SPF in TXT dig +short TXT _dmarc.example.com # DMARC policy/reporting # DKIM selectors vary: s1/s2/google/selector1... 
look at headers or vendor docs dig +short TXT selector1._domainkey.example.com ``` (SPF/DKIM/DMARC specs and semantics.)[3](https://www.rfc-editor.org/rfc/rfc7208)[4](https://www.rfc-editor.org/rfc/rfc6376)[5](https://www.rfc-editor.org/rfc/rfc7489) - **Pivot on CNAMEs and PTRs:** ``` dig +noall +answer foo.example.com CNAME dig -x +short ``` (CNAME/PTR defined in RFC 1035; PTR ops in RFC 1912.)[2](https://www.rfc-editor.org/rfc/rfc1035)[11](https://datatracker.ietf.org/doc/html/rfc1912) --- ## Need the *complete* list of DNS types? The **IANA DNS parameters registry** is the canonical, always‑current list with links to every defining RFC (includes classic, modern, and experimental types). Use it when you encounter an unknown **TYPE###** in outputs.[1](https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml) > **Tip:** Per **RFC 3597**, any resolver that doesn’t know a type should still carry it and tools can query by number (e.g., `TYPE65` for HTTPS). This is why `dig -t TYPE65` works everywhere, even if `HTTPS` mnemonic isn’t recognized.[27](https://datatracker.ietf.org/doc/html/rfc3597) --- ## Recon pitfalls & pro‑tips - **SRV targets MUST NOT be CNAMEs** (per RFC 2782). Finding them is a good misconfig clue and sometimes breaks clients in ways that leak fallbacks.[8](https://www.rfc-editor.org/rfc/rfc2782)[9](https://serverfault.com/questions/765743/published-srv-records-pointing-to-cname-alias-in-violation-of-rfc-2782) - **NSEC vs NSEC3:** If you see NSEC (not NSEC3), **zone walking** may enumerate labels. If NSEC3 is used with zero iterations/salt choices per guidance, walking is impractical, but you can still learn that DNSSEC is deployed.[19](https://www.rfc-editor.org/rfc/rfc4034)[20](https://www.rfc-editor.org/rfc/rfc5155) - **CAA:** Look for **gaps**—e.g., wildcard allowed for `*.example.com` but a sub‑delegate has no CAA (or different CA), which can hint at issuance hygiene issues or forgotten sub‑zones.[14](https://www.rfc-editor.org/rfc/rfc8659.html) - **HTTPS(65)/SVCB(64):** Check both the **apex** and **`www`**; adoption rates differ and sometimes only one publishes ALPN/ECH hints that reveal **alt endpoints**.[7](https://standards.globalspec.com/std/14647706/rfc-9460) --- ## Want a CSV of “recon‑useful types” with RFCs and `dig` examples? Say the word and I’ll generate a **ready‑to‑import CSV** (type, RFC, why it matters, sample queries) or a tiny **Bash/Python** that walks a domain and extracts **A/AAAA/CNAME/MX/NS/SOA/TXT(SPF,DKIM,DMARC)/SRV/SVCB/HTTPS/TLSA/CAA/NAPTR/SSHFP** plus notable SVCB params. (I’ll tailor it to your workflow and favorite tools.) 
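As a rough illustration (not a tool in this repo), the sweep described above can also be scripted in a few lines of Rust by shelling out to `dig`; the domain and the record-type list below are placeholders to adapt to your workflow.

```rust
use std::process::Command;

/// Query a handful of high-signal record types for one domain by shelling
/// out to `dig`, mirroring the manual commands listed above.
fn main() {
    let domain = "example.com"; // placeholder target
    // TYPE64/TYPE65 are SVCB/HTTPS queried by number (RFC 3597 style).
    let types = ["A", "AAAA", "CNAME", "MX", "NS", "SOA", "TXT", "CAA", "TYPE64", "TYPE65"];

    for rtype in types {
        let output = Command::new("dig")
            .args(["+short", "-t", rtype, domain])
            .output()
            .expect("failed to run dig");
        let answers = String::from_utf8_lossy(&output.stdout);
        let answers = answers.trim();
        if !answers.is_empty() {
            println!("{domain} {rtype}:\n{answers}\n");
        }
    }
}
```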
--- ### Sources (selected) - **IANA DNS RR TYPE registry** (authoritative list and references).[1](https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml) - **Core DNS**: RFC 1035 (A/CNAME/MX/NS/SOA/TXT/PTR/HINFO…).[2](https://www.rfc-editor.org/rfc/rfc1035) - **AAAA (IPv6)**: RFC 3596.[10](https://www.rfc-editor.org/rfc/rfc3596) - **SRV**: RFC 2782 (target not CNAME).[8](https://www.rfc-editor.org/rfc/rfc2782)[9](https://serverfault.com/questions/765743/published-srv-records-pointing-to-cname-alias-in-violation-of-rfc-2782) - **SVCB/HTTPS**: RFC 9460 (service binding & HTTPS RRs).[7](https://standards.globalspec.com/std/14647706/rfc-9460) - **TLSA (DANE)**: RFC 6698 (+ SRV usage guidance).[12](https://www.rfc-editor.org/rfc/rfc6698)[13](https://archive.org/details/rfc7673) - **CAA**: RFC 8659 (obsoletes 6844).[14](https://www.rfc-editor.org/rfc/rfc8659.html) - **NAPTR**: RFC 3403.[15](https://datatracker.ietf.org/doc/html/rfc3403) - **DNAME**: RFC 6672.[16](https://www.rfc-editor.org/rfc/rfc6672) - **SSHFP**: RFC 4255.[18](https://datatracker.ietf.org/doc/html/rfc4255) - **PTR ops**: RFC 1912.[11](https://datatracker.ietf.org/doc/html/rfc1912) - **DNSSEC RRs**: RFC 4034 (DNSKEY/DS/RRSIG/NSEC) & RFC 5155 (NSEC3).[19](https://www.rfc-editor.org/rfc/rfc4034)[20](https://www.rfc-editor.org/rfc/rfc5155) - **SPF/DKIM/DMARC**: RFC 7208, RFC 6376, RFC 7489 (DMARCbis in progress).[3](https://www.rfc-editor.org/rfc/rfc7208)[4](https://www.rfc-editor.org/rfc/rfc6376)[5](https://www.rfc-editor.org/rfc/rfc7489) - **EDNS(0)/OPT**: RFC 6891; **ECS**: RFC 7871.[23](https://www.rfc-editor.org/rfc/rfc6891)[25](https://www.rfc-editor.org/rfc/rfc7871) - **Unknown type handling**: RFC 3597 (use `TYPE###`).[27](https://datatracker.ietf.org/doc/html/rfc3597) --- If you share a target (or a sanitized sample), I can **run a focused sweep** and return a **shortlist of new domains/endpoints** worth probing next. ``` ### `scanners/mrpf_dns_resolver/src/lib.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/src/lib.rs` ```rust pub mod progress; mod receive; pub mod scanner; mod transmit; pub use progress::DnsMessage; pub use scanner::DnsResolver; ``` ### `scanners/mrpf_dns_resolver/src/progress.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/src/progress.rs` ```rust use mrpf_engine::ProgressHandler; use mrpf_engine::network::dns::DnsResourceRecord; use std::collections::{HashMap, HashSet}; use std::sync::mpsc::Receiver; /// Progress Messages sent during a TCP SYN scan. pub enum DnsMessage { // A found TCP Open Port and the corresponding IP address Result(DnsResourceRecord), /// Receive statistics RxStats { packet_count: u32, full_read_buffer_count: u8, }, /// Transmit statistics TxStats { packet_count: u32, full_write_buffer_count: u8, }, /// The scan is finished Finished, } /// Headless ProgressHandler for the DNS resolver /// /// The handler will ensure the resolver is stopped and the results are collected. /// No other progress messages will be parsed since we're running in a headless manner. 
#[derive(Debug)] pub struct DnsProgressHandler; impl ProgressHandler for DnsProgressHandler { type Message = DnsMessage; type Results = HashMap>>; fn handler(&mut self, receiver: Receiver) -> Self::Results { let mut results: Self::Results = HashMap::new(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { // Being very pragmatic here and only collecting results that // we typically care about. Strictly speaking we might want // to collect a few more properties on some records, as well // as use better types like Ipv4Addr for ips. For now, // this is sufficient. DnsMessage::Result(result) => match result { DnsResourceRecord::A { owner_name, addr } => { results .entry(owner_name) .or_default() .entry("A".to_string()) .or_default() .insert(addr.to_string()); } DnsResourceRecord::AAAA { owner_name, addr } => { results .entry(owner_name) .or_default() .entry("AAAA".to_string()) .or_default() .insert(addr.to_string()); } DnsResourceRecord::CNAME { owner_name, cname } => { results .entry(owner_name) .or_default() .entry("CNAME".to_string()) .or_default() .insert(cname); } DnsResourceRecord::MX { owner_name, exchange, .. } => { results .entry(owner_name) .or_default() .entry("MX".to_string()) .or_default() .insert(exchange); } DnsResourceRecord::PTR { owner_name, ptrdname, } => { results .entry(owner_name) .or_default() .entry("PTR".to_string()) .or_default() .insert(ptrdname); } _ => {} }, DnsMessage::Finished => break, _ => {} } } results } } ``` ### `scanners/mrpf_dns_resolver/src/receive.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/src/receive.rs` ```rust use super::progress::DnsMessage; use mrpf_engine::network::arp::handle_arp_requests; use mrpf_engine::network::dns::parse_dns_response; use mrpf_engine::{ScannerConfig, unwrap_or_continue}; use pnet::datalink::DataLinkReceiver; use pnet::packet::Packet; use pnet::packet::ethernet::{EtherTypes, EthernetPacket}; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::udp::UdpPacket; use std::sync::mpsc::Sender; use std::time::Duration; use tracing::debug; /// The main receive handler for the network engine. /// /// The function runs in a loop and receives packets from the network interface. /// /// It uses the SYN cookie algorithm to determine if the received packet belongs /// to the scanner. pub(crate) fn receive_handler( mut rx: Box, packet_tx: Sender>, progress_tx: Sender, config: ScannerConfig, ) { let mut packet_count = 0; let full_read_buffer_count = 0; let progress_interval: Duration = Duration::from_secs(1); // We explicitly set the last progress to half the interval to try and // be out of sync with the transmit thread. This should spread the load // of the progress updates a bit more. let mut last_progress = std::time::Instant::now() - (progress_interval / 2); loop { match rx.next() { Ok(packet) => { // Determine if the received packet belongs to our scanner // TODO: I want to see if I can do ARP spoofing so I can filter packets to the spoofed // ARP address early here to determine it comes from the scanner. This will // allow us to skip the TCP cookie generation for a proper comparison. 
let ethernet_packet = unwrap_or_continue!(EthernetPacket::new(packet)); if handle_arp_requests(ðernet_packet, &config, &packet_tx) { // ARP request handled continue; } if ethernet_packet.get_destination() != config.src_mac || ethernet_packet.get_ethertype() != EtherTypes::Ipv4 { // Packet was not destined for us continue; } let ipv4_packet = unwrap_or_continue!(Ipv4Packet::new(ethernet_packet.payload())); if ipv4_packet.get_destination() != config.src_ip { // This is especially useful if we have a dedicated (e.g. spoofed) IP for the scanner continue; } if ipv4_packet.get_next_level_protocol() != IpNextHeaderProtocols::Udp { // println!("Received non-TCP packet"); continue; } let udp_packet = unwrap_or_continue!(UdpPacket::new(ipv4_packet.payload())); let resource_records = unwrap_or_continue!(parse_dns_response(udp_packet.payload())); for record in resource_records { debug!(record = ?record, "Received DNS record"); progress_tx.send(DnsMessage::Result(record)).unwrap(); } packet_count += 1; } Err(e) => { panic!("Failed to receive packet: {}", e) } } // Send progress update every second if last_progress.elapsed() > progress_interval { if progress_tx .send(DnsMessage::RxStats { packet_count, full_read_buffer_count, }) .is_ok() { packet_count = 0 } last_progress = std::time::Instant::now(); } } } ``` ### `scanners/mrpf_dns_resolver/src/scanner.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/src/scanner.rs` ```rust use super::progress::{DnsMessage, DnsProgressHandler}; use super::receive::receive_handler; use super::transmit::transmit_handler; use mrpf_engine::ProgressHandler; use mrpf_engine::ScannerConfig; use mrpf_engine::network::arp::get_router_arp_address; use mrpf_engine::network::datalink::create_datalink_channel; use mrpf_engine::network::dns::DnsQueryType; use std::collections::{HashMap, HashSet}; use std::sync::mpsc; use std::thread; /// The TCP Syn Scanner. /// /// It is responsible for managing the network /// interfaces, sending and receiving packets, and handling status updates. /// /// The core of the engine is based on moving the problem of quickly scanning a large volume of /// network hosts from an IO bound to a CPU bound problem. This is done by separating the transmit /// and receive of TCP/IP network stack to two separate threads. /// /// Correlating send and receive packets is being done by SYN cookies instead of relying on sockets /// per connection. /// /// TODO: /// Spoofing MAC still doesn't work. The router doesn't seem to respond to my arp request for router address /// we're only getting it implicitly after a while due to the thing broadcasting itself. /// /// *TODO* This should become only public inside the crate. Our consumers should /// instead be using the various scanner implementations. 
/// /// # Example /// /// ```ignore /// use mrpf::network::datalink::get_interface; /// use mrpf::{ScannerConfig, DnsResolver}; /// use std::net::Ipv4Addr; /// /// let interface = get_interface(None).unwrap(); /// let src_mac = interface.mac.unwrap(); /// /// let config = ScannerConfig { /// interface, /// src_ip: Ipv4Addr::new(192, 168, 1, 182), /// src_mac, /// router_ip: Ipv4Addr::new(192, 168, 1, 1), /// }; /// /// let scanner = DnsResolver::new(config); /// let ipv4_addresses = vec![Ipv4Addr::new(192, 168, 1, 100)]; /// let ports = vec![80, 443]; /// /// let results = scanner.run(ipv4_addresses, ports); /// for (ip, open_ports) in results { /// println!("IP: {}, Open Ports: {:?}", ip, open_ports); /// } /// ``` pub struct DnsResolver { /// Scanner Configuration. /// /// This holds configuration data like the source MAC and IP address, the router IP address, /// and the network interface to use. pub config: ScannerConfig, } impl DnsResolver { /// Initialize engine pub fn new(config: ScannerConfig) -> DnsResolver { // Mac spoofing still not working, perhaps it's something to do with // unicast/multicat + local/manufactor assigned code? this is encoded // in the first and second significant bits of the first octet // let src_mac: MacAddr = MacAddr::new(0x00, 0xad, 0xde, 0xad, 0xde, 0xad); DnsResolver { config } } /// Run a TCP Syn Scan with the default progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `ipv4_addresses` - The list of IPv4 addresses to scan. /// * `ports` - The list of ports to scan. /// /// # Returns /// /// A HashMap containing the results of the scan. pub fn run( &self, domain_names: HashSet, query_type: DnsQueryType, ) -> HashMap>> { let progress_handler = DnsProgressHandler; self.run_with_handle(domain_names, query_type, progress_handler) } /// Run a TCP Syn Scan with a specific progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `ipv4_addresses` - The list of IPv4 addresses to scan. /// * `ports` - The list of ports to scan. /// * `progress_handler` - The progress handler function to handle status updates. /// /// # Returns /// /// A JoinHandle containing the results of the scan. 
pub fn run_with_handle< P: ProgressHandler< Message = DnsMessage, Results = HashMap>>, >, >( &self, domain_names: HashSet, query_type: DnsQueryType, mut progress_handler: P, ) -> HashMap>> { let (mut tx, mut rx) = create_datalink_channel(&self.config.interface, 4096 * 4, 4096 * 4); // Retrieve the router mac address let router_mac = get_router_arp_address( self.config.src_mac, self.config.src_ip, self.config.router_ip, &mut *tx, &mut *rx, ) .expect("Router MAC address should be found"); // Create a channel and separate thread for handling progress updates let (progress_tx, progress_rx) = mpsc::channel(); let progress_handle = thread::Builder::new() .name("progress_thread".to_string()) .spawn(move || progress_handler.handler(progress_rx)) .expect("Progress thread should have been spawned"); // Create a channel for sending packets from the receive thread to the transmit thread let (packet_tx, packet_rx) = mpsc::channel(); // Create a thread for receiving packets let config = self.config.clone(); let progress_tx_clone = progress_tx.clone(); let receive_handle = thread::Builder::new() .name("receive_thread".to_string()) .spawn(move || { receive_handler(rx, packet_tx, progress_tx_clone, config); }) .expect("Receive thread should have been spawned"); // Create a thread for sending packets let config = self.config.clone(); let transmit_handle = thread::Builder::new() .name("transmit_thread".to_string()) .spawn(move || { transmit_handler( tx, packet_rx, progress_tx, config, router_mac, query_type, domain_names, ); }) .expect("transmit thread should have been spawned"); // When the transmit handle is finished, we can drop the receive handle // and return the results from the progress_handle. transmit_handle.join().unwrap(); // TODO: This actually doesn't stop the receive thread. We will have to stop the thread by // passing a AtomicBool Arc to the transmit and receive thread so we can signal it's done. // There might be other control flags we'd like to set such as pausing the loop, changing // the rate limiting settings, etc. Maybe we can have a shared struct that the progress, // transmit and receive threads use to communicate with each other. drop(receive_handle); progress_handle.join().unwrap() } } ``` ### `scanners/mrpf_dns_resolver/src/transmit.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_dns_resolver/src/transmit.rs` ```rust use super::progress::DnsMessage; use mrpf_engine::ScannerConfig; use mrpf_engine::network::datalink::send_packet; use mrpf_engine::network::dns::{DnsQueryType, create_dns_query_packet_v4}; use pnet::datalink::DataLinkSender; use pnet::util::MacAddr; use rand::Rng; use std::collections::HashSet; use std::net::Ipv4Addr; use std::sync::mpsc::{Receiver, Sender}; use std::time::{Duration, Instant}; use tracing::info; // Interval of status updates and reset of our packet counts const PROGRESS_INTERVAL: Duration = Duration::from_secs(1); /// The main function responsible for sending packets. /// /// To quote Masscan: /// /// ```quote /// THIS IS WHERE ALL THE EXCITEMENT HAPPENS!!!! /// 90% of CPU cycles are in this function. /// ``` pub(crate) fn transmit_handler( mut tx: Box, _packet_rx: Receiver>, progress_tx: Sender, mut config: ScannerConfig, // TODO move router mac to the builder? Then we only have to do it once // instead of each scan and it makes it easier to pass on values. 
router_mac: MacAddr, query_type: DnsQueryType, domain_names: HashSet, ) { // Progress counters let mut packets_per_second: u32 = 0; let mut full_write_buffer_count = 0; let mut last_progress = Instant::now() - PROGRESS_INTERVAL; // Start rate limiter config.rate_limiter.start(); let mut rng = rand::rng(); let src_port_range = 32768..=65535; let src_port: u16 = rng.random_range(src_port_range.clone()); let resolver = Ipv4Addr::new(8, 8, 8, 8); // TMP just use google dns for domain_name in domain_names { // Construct DNS query packet (A record) let dns_request = if query_type == DnsQueryType::PTR { let domain_name = match mrpf_engine::network::dns::domain_name_to_ptr(&domain_name) { Some(name) => name, None => { info!("Skipping invalid PTR domain name: {}", domain_name); continue; } }; info!("Sending PTR query for: {}", domain_name); info!("NOTE: SOMEHOW I'M NOT GETTING RESPONSES BACK FOR THESE RECORDS> TODO FIX"); create_dns_query_packet_v4( &config, &router_mac, src_port, &resolver, 53, &domain_name, query_type, ) } else { create_dns_query_packet_v4( &config, &router_mac, src_port, &resolver, 53, &domain_name, query_type, ) }; if let Some((_, packet)) = dns_request { config.rate_limiter.wait_for_token(); send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } // Send status updates to our process handler and reset counters. if last_progress.elapsed() > PROGRESS_INTERVAL { send_progress_update( progress_tx.clone(), &mut packets_per_second, &mut full_write_buffer_count, ); last_progress = Instant::now(); } } std::thread::sleep(Duration::from_secs(10)); // Wait a bit to allow last packets to be received // Send final statistic update send_progress_update( progress_tx.clone(), &mut packets_per_second, &mut full_write_buffer_count, ); // Scan is finished, signal our process handler progress_tx .send(DnsMessage::Finished) .expect("Send finished message"); } /// Send transmit progress update #[inline] fn send_progress_update( progress_tx: Sender, packets_per_second: &mut u32, full_write_buffer_count: &mut u8, ) { progress_tx .send(DnsMessage::TxStats { packet_count: *packets_per_second, full_write_buffer_count: *full_write_buffer_count, }) .unwrap(); *packets_per_second = 0; *full_write_buffer_count = 0; } ``` ### `scanners/mrpf_engine/benches/intranges.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/benches/intranges.rs` ```rust use criterion::{Criterion, criterion_group, criterion_main}; use mrpf_engine::{PortRange, PortRanges}; fn bench_add_range(c: &mut Criterion) { c.bench_function("Add range to Ports", |b| { b.iter(|| { let mut range_list = PortRanges::new(vec![PortRange::new(10, 20)]); range_list.add_range(PortRange::new(15, 25)); }) }); } fn bench_add_number(c: &mut Criterion) { c.bench_function("Add number to Ports", |b| { b.iter(|| { let mut range_list = PortRanges::new(vec![PortRange::new(10, 20)]); range_list.add_port(21); }) }); } fn bench_iteration(c: &mut Criterion) { c.bench_function("Iterate over Ports", |b| { b.iter(|| { let range_list = PortRanges::new(vec![PortRange::new(10, 20), PortRange::new(30, 40)]); for _ in range_list {} }) }); } criterion_group!(benches, bench_add_range, bench_add_number, bench_iteration); criterion_main!(benches); ``` ### `scanners/mrpf_engine/optimize_ip_calc.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/optimize_ip_calc.rs` ```rust // Optimizations that we could look into // Obviously this is something that people have looked into a LOT, just want the confirmation that pnet is // not taking shortcuts 
here. // There's actually a RFC about calculating the checksum: // https://www.ietf.org/rfc/rfc1071.txt // the ip checksum calc is taking a significant amount of time, most of it spent in the // sum_be_words function. We could look into optimizing this function to speed up the // overall calculation. /// Sum all words (16 bit chunks) in the given data. The word at word offset /// `skipword` will be skipped. Each word is treated as big endian. fn sum_be_words(data: &[u8], skipword: usize) -> u32 { if data.len() == 0 { return 0; } let len = data.len(); let mut cur_data = &data[..]; let mut sum = 0u32; let mut i = 0; while cur_data.len() >= 2 { if i != skipword { sum += u16::from_be_bytes(cur_data[0..2].try_into().expect( "Should be safe to unwrap because we verified there are at least 2 bytes", )) as u32; } cur_data = &cur_data[2..]; i += 1; } // If the length is odd, make sure to checksum the final byte if i != skipword && len & 1 != 0 { sum += (data[len - 1] as u32) << 8; } sum } // Copilot said this could optimize it: // Yes, there are a few optimizations you can apply to this function to improve its performance: // - Avoid Slicing in the Loop: Instead of slicing cur_data in each iteration, use an index to access elements directly. This reduces the overhead of creating new slices. // - Use Iterators: Rust’s iterators can be more efficient and expressive. You can use chunks_exact to process the data in 16-bit chunks. // - Unroll the Loop: Manually unrolling the loop can reduce the overhead of loop control, though this is more of a micro-optimization. // Here’s an optimized version of your function: // // running the benchmark it seems it's slower than the original function fn sum_be_words_optimized(data: &[u8], skipword: usize) -> u32 { if data.is_empty() { return 0; } let len = data.len(); let mut sum = 0u32; let mut i = 0; // Process 16-bit chunks for chunk in data.chunks_exact(2) { if i != skipword { let word = u16::from_be_bytes([chunk[0], chunk[1]]); sum += word as u32; } i += 1; } // Handle the remainder byte if the length is odd if len % 2 != 0 && i != skipword { sum += (data[len - 1] as u32) << 8; } sum } // Here is a new optimized function after I tolkd it to look at the RFC // still slower but not as slow as the initial 'optimized' function fn sum_be_words_optimized_2(data: &[u8], skipword: usize) -> u32 { if data.is_empty() { return 0; } let len = data.len(); let mut sum = 0u32; let mut i = 0; // Process 16-bit chunks while i < len / 2 { if i != skipword { let word = u16::from_be_bytes([data[2 * i], data[2 * i + 1]]); sum = sum.wrapping_add(word as u32); } i += 1; } // Handle the remainder byte if the length is odd if len % 2 != 0 && i != skipword { sum = sum.wrapping_add((data[len - 1] as u32) << 8); } // Add end-around carry while (sum >> 16) != 0 { sum = (sum & 0xFFFF) + (sum >> 16); } sum } ``` ### `scanners/mrpf_engine/src/config.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/config.rs` ```rust use crate::ratelimit::RateLimiter; use pnet::datalink::NetworkInterface; use pnet::util::MacAddr; use rand::Rng; use std::net::Ipv4Addr; use std::time::Duration; use typed_builder::TypedBuilder; /// Configuration values for the scanner. /// /// The configuration should be generic for any type of scanner that is implemented. /// This means it should be able to work with any type of scanner, whether it is using /// the 'masscan' style of sparate threads for rx and tx, or using a higher level /// library like hyper. 
/// /// TODO: instead of Clone, can I just define Send and Sync to share it over threads? /// The config should be immutable after creation, so it should be safe to share /// I would have to create getter methods for the fields. #[derive(Debug, PartialEq, Clone, TypedBuilder)] pub struct ScannerConfig { /// The network interface to send and listen for packets on. pub interface: NetworkInterface, /// The source MAC address to use for sending packets. As per the source IP /// address, this will be spoofed to the router MAC address. pub src_mac: MacAddr, /// The source IP address to use for sending packets. /// /// This can be set to a non-active IP address within the current subnet. /// The engine will advertise its IP address to the router using ARP /// to spoof the IP address. /// /// Spoofing the IP address makes it easier and more /// performant to determine if a packet belongs to the scanner in the /// receive thread, as we don't have to generate a syn cookie to /// determine if a packet is a response to a packet we sent. pub src_ip: Ipv4Addr, /// The router IP address to use for sending packets. The scanner will /// automatically retrieve the Mac address of the router and use it for /// sending packets. pub router_ip: Ipv4Addr, /// Entropy used to for our randomization of SYN cookies #[builder(default = { let mut rng = rand::rng(); rng.random::() })] pub entropy: u64, /// The timeout for establishing a connection #[builder(default = Duration::from_secs(5))] pub connection_timeout: Duration, /// The rate limiter for the scanner #[builder(default = RateLimiter::default())] pub rate_limiter: RateLimiter, } #[cfg(test)] mod tests { use super::*; use pnet::datalink::interfaces; fn get_test_interface() -> NetworkInterface { interfaces() .into_iter() .next() .expect("No network interfaces found") } #[test] fn test_engine_config_builder() { let interface = get_test_interface(); let src_mac = MacAddr::new(0, 1, 2, 3, 4, 5); let src_ip = Ipv4Addr::new(192, 168, 1, 1); let router_ip = Ipv4Addr::new(192, 168, 1, 254); let entropy = 12345; let config = ScannerConfig::builder() .interface(interface.clone()) .src_mac(src_mac) .src_ip(src_ip) .router_ip(router_ip) .entropy(entropy) .build(); assert_eq!(config.interface, interface); assert_eq!(config.src_mac, src_mac); assert_eq!(config.src_ip, src_ip); assert_eq!(config.router_ip, router_ip); assert_eq!(config.entropy, entropy); } } ``` ### `scanners/mrpf_engine/src/connections/mod.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/connections/mod.rs` ```rust pub mod tcp; pub use tcp::TcpReassemblyBuffer; /// A Connection type used in the ConnectionTable pub trait Connection { /// Add data to an existing connection /// /// The given data will be added to the connection buffer. /// /// # Arguments /// /// * `data` - The data to add to the connection /// * `seq` - The sequence number of the data /// /// # Returns /// /// * `bool` - True if the data was added to the connection, false if the connection does not exist fn add_segment(&mut self, data: &[u8], seq: u32); /// Get the reassembled data for a connection /// /// After processing the reassembled data, clear_buffer should be called before /// processing further packets. /// /// # Arguments /// * `ip_address` - The IP address of the connection /// * `port` - The port of the connection /// /// # Returns /// * `Option>` - The decrypted buffer for the connection, or None if the TCP segment is not yet complete /// or the connection does not exist. 
fn get_reassembled_buffer(&mut self) -> Option>; /// Retrieve the next sequence number for sending data fn next_seq(&self) -> u32; } ``` ### `scanners/mrpf_engine/src/connections/tcp.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/connections/tcp.rs` ```rust use super::Connection; use std::collections::{BTreeMap, HashMap}; /// Struct to hold a table of TCP connections. /// /// This is used to hold the data for all the TCP connections that are currently /// being processed. It is used to hold received data until the protocol above /// (e.g. HTTP) can fully process it. /// /// TODO: Should we check the table periodically to remove connections that have been idle /// for too long? pub struct TcpConnectionTable { connections: HashMap, } impl TcpConnectionTable { /// Initialize new connection table pub fn new() -> Self { Self { connections: HashMap::new(), } } /// Add a new connection to the table. /// /// # Arguments /// /// * `syn_cookie` - The SYN cookie of the connection /// * `connection` - The connection to add to the table pub fn insert(&mut self, syn_cookie: u32, connection: T) { self.connections.insert(syn_cookie, connection); } /// Get the associated TlsConnection pub fn get(&mut self, syn_cookie: u32) -> Option<&mut T> { self.connections.get_mut(&syn_cookie) } /// Remove a connection from the table. pub fn remove(&mut self, syn_cookie: u32) { self.connections.remove(&syn_cookie); } } impl Default for TcpConnectionTable { fn default() -> Self { Self::new() } } /// A single TCP Connection without any upper layer protocol. pub struct TcpConnection { /// Holds the current TCP segments buffer: TcpReassemblyBuffer, // The next sequence number our connection will use for sending packets to the remote host // next_seq: u32, } impl TcpConnection { /// Create a new TcpConnection /// /// # Arguments /// /// * `next_seq` - The next sequence number our connection will use for sending packets to the remote host /// * `expected_seq` - The sequence number of the first packet in the next sequence we expect from the remote host pub fn new(next_seq: u32, _expected_seq: u32) -> Self { let buffer = TcpReassemblyBuffer::new(next_seq); // Self { buffer, next_seq } Self { buffer } } } impl Connection for TcpConnection { /// Add data to an existing connection /// /// The given data will be added to the connection buffer. /// /// # Arguments /// /// * `data` - The data to add to the connection /// * `seq` - The sequence number of the data /// /// # Returns /// /// * `bool` - True if the data was added to the connection, false if the connection does not exist fn add_segment(&mut self, data: &[u8], seq: u32) { self.buffer.insert_segment(seq, data.to_vec()); } /// Get the reassembled data for a connection /// /// After processing the reassembled data, clear_buffer should be called before /// processing further packets. /// /// # Arguments /// * `ip_address` - The IP address of the connection /// * `port` - The port of the connection /// /// # Returns /// * `Option>` - The decrypted buffer for the connection, or None if the TCP segment is not yet complete /// or the connection does not exist. fn get_reassembled_buffer(&mut self) -> Option> { self.buffer.reassemble() } /// Get the next sequence number fn next_seq(&self) -> u32 { todo!("Implement retrieving the next sequence number"); } } /// Struct to hold TCP segment data until it can be reassembled. 
///
/// There might be a more efficient solution already built here: https://github.com/weary/reass
pub struct TcpReassemblyBuffer {
    /// Contains the sequence number of the first packet in the sequence.
    ///
    /// This is a small optimization to avoid trying to parse the buffer
    /// before at least the first packet has arrived.
    first_seq: u32,
    /// Stores the sequence number and the corresponding data.
    ///
    /// A BTreeMap is used instead of a HashMap because it is sorted and we can
    /// easily iterate over the data in order.
    buffer: BTreeMap<u32, Vec<u8>>,
    /// Expected average size of the reassembled packet.
    ///
    /// Used to pre-allocate the buffer when constructing the reassembled packet.
    expected_size: usize,
}

impl TcpReassemblyBuffer {
    /// Create a new TcpReassemblyBuffer struct with the sequence number of the first
    /// expected TCP segment.
    pub fn new(first_seq: u32) -> Self {
        Self {
            buffer: BTreeMap::new(),
            first_seq,
            expected_size: 2048,
        }
    }

    /// Insert a segment into the buffer.
    ///
    /// The segment is inserted into the buffer at the sequence number.
    ///
    /// # Arguments
    ///
    /// * `seq` - The sequence number of the segment
    /// * `data` - The data of the segment
    pub fn insert_segment(&mut self, seq: u32, data: Vec<u8>) {
        self.buffer.insert(seq, data);
    }

    /// Clear the buffer.
    ///
    /// Call this when we have processed the previous reassembled packets and want to process further
    /// packets.
    ///
    /// # Arguments
    ///
    /// * `first_seq` - The sequence number of the first packet in the next sequence
    pub fn clear(&mut self, first_seq: u32) {
        // TODO: I should be able to calculate the first_seq here from the existing buffer?
        // It would be self.first_seq + self.buffer.len() as u32 + 1 (I think +1?)
        self.buffer.clear();
        self.first_seq = first_seq;
    }

    /// Reassemble the data from the buffer.
    pub fn reassemble(&mut self) -> Option<Vec<u8>> {
        // The first packet is not yet received
        if !self.buffer.contains_key(&self.first_seq) {
            return None;
        }
        let mut reassembled_data = Vec::with_capacity(self.expected_size);
        let mut expected_seq = self.first_seq;
        for (&seq, segment) in self.buffer.iter() {
            if seq != expected_seq {
                // A segment is still missing
                return None;
            }
            reassembled_data.extend(segment);
            expected_seq += segment.len() as u32;
        }
        Some(reassembled_data)
    }
}
```

### `scanners/mrpf_engine/src/lib.rs`

**Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/lib.rs`

```rust
//! # My Rusty Pension Fund.
//!
//! This crate provides tools for network scanning and reconnaissance.
//!
//! It is designed to be:
//! - Easy to use
//! - Fast
//! - Extensible
//! - Unreliable (Yes, unreliable. We prefer speed and limited resources over reliability)
//!
//! Bindings for Python will be available eventually in the MyPensionFund package.
//!
//! For more information see the [official website](https://www.mortimer.nl/mrpf)
//!
//! To get started with MRPF, see [`mrpf::Engine`] for more details.
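//!
//! A minimal configuration sketch (the interface, addresses and builder values
//! below are placeholders; see [`ScannerConfig`] for the full set of fields):
//!
//! ```ignore
//! let config = ScannerConfig::builder()
//!     .interface(interface)
//!     .src_mac(src_mac)
//!     .src_ip(Ipv4Addr::new(192, 168, 1, 1))
//!     .router_ip(Ipv4Addr::new(192, 168, 1, 254))
//!     .build();
//! ```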
mod config; pub mod connections; mod models; pub mod macros; pub mod network; pub mod progress; pub mod ratelimit; pub mod utils; pub use crate::config::ScannerConfig; pub use crate::connections::Connection; pub use crate::connections::tcp::{TcpConnection, TcpConnectionTable, TcpReassemblyBuffer}; pub use models::{Ipv4AddressRange, Ipv4AddressRanges, PortRange, PortRanges}; pub use crate::progress::ProgressHandler; pub use crate::ratelimit::RateLimiter; ``` ### `scanners/mrpf_engine/src/macros.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/macros.rs` ```rust /// This macro is used to unwrap an Option and continue the loop if the Option is None /// Due to the layered approach of the packet processing, we often need to check if the Option is None /// and continue the loop if so. /// /// It is used for instance in the main receive loop as well as the ARP module to retrieve the router MAC address. #[macro_export] macro_rules! unwrap_or_continue { ($option:expr_2021) => { match $option { Some(value) => value, None => continue, } }; } ``` ### `scanners/mrpf_engine/src/models/feistel.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/models/feistel.rs` ```rust /// Generate a random number from an integer range. /// /// Good description of the algorithm: /// https://asim.ihsan.io/random-permutations-with-constant-space/ /// https://github.com/asimihsan/permutation-iterator-rs /// /// # Arguments /// /// * `input` - The input value for the Feistel round. This will be the current iteration through the range /// * `seed` - The seed value used for the Feistel round. /// * `round` - The current round number. /// * `range` - The range of values to generate from. pub(crate) fn feistel_round(input: u128, seed: u128, round: u128, range: u128) -> u128 { let half_range = range / 2; let left = input / half_range; let right = input % half_range; let new_right = (right + (seed ^ round)) % half_range; (left * half_range) + new_right } ``` ### `scanners/mrpf_engine/src/models/ipv4_address_ranges.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/models/ipv4_address_ranges.rs` ```rust use super::Ipv4AddressRange; use super::feistel_round; use ipnet::Ipv4Net; use serde::Deserializer; use serde::de::{self, Deserialize}; use serde::ser::{Serialize, SerializeSeq, Serializer}; use serde_json::Value; use std::collections::HashSet; use std::fmt; use std::net::Ipv4Addr; /// Represent a list of ranges of integer values /// /// This struct will allow us to store a list of ranges of integer values. #[derive(Debug, Clone)] pub struct Ipv4AddressRanges { ranges: Vec, /// The current index of iteration current_index: u32, /// A seed for randomization of the index /// /// Set this to the same value to get a deterministic output. /// This is useful for unit testing or resuming a previous scan seed: u16, /// The amount of rounds te randomizer runs through the /// feistel algorithm. /// /// Higher values will increase the /// amount of randomization. rounds: u16, /// Total length of the ipv4 addresses in the combined ranges /// This can potentially be the full ipv4 address space /// of 4,294,967,296. 
len: u32, } impl Ipv4AddressRanges { /// Get the ranges stored in the Ipv4Addresses pub fn ranges(&self) -> &Vec { &self.ranges } // Return the total length of ip addresses stored inside the Ipv4AddressRanges combined pub fn len(&self) -> u32 { self.len } /// Check if the list is empty pub fn is_empty(&self) -> bool { self.len == 0 } /// Create a new Ipv4Addresses /// /// This function will create a new Ipv4Addresses. The ranges will be sorted and combined. pub fn new(ranges: Vec) -> Self { let mut int_range_list = Ipv4AddressRanges { ranges, current_index: 0, seed: 42, // TODO: randomize the seed rounds: 10, len: 0, }; int_range_list.sort_and_combine(); int_range_list } /// Create a new Ipv4Addresses with a user provided seed. /// /// This function will create a new Ipv4Addresses. The ranges will be sorted and combined. pub fn new_with_seed(ranges: Vec, seed: u16, rounds: u16) -> Self { let mut int_range_list = Ipv4AddressRanges { ranges, current_index: 0, seed, rounds, len: 0, }; int_range_list.sort_and_combine(); int_range_list } /// Add a range to the list /// /// NOTE: This can be an expensive operation if the list is large, since /// we will need to sort and combine the ranges. pub fn add_range(&mut self, range: Ipv4AddressRange) -> &mut Self { self.ranges.push(range); self.sort_and_combine(); self } /// Add multiple ranges to the list /// /// NOTE: This can be an expensive operation if the list is large, since /// we will need to sort and combine the ranges. pub fn add_ranges(&mut self, ranges: Ipv4AddressRanges) -> &mut Self { self.ranges.extend(ranges.ranges); self.sort_and_combine(); self } /// Add a single number to the list /// /// NOTE: This can be an expensive operation if the list is large, since /// we will need to sort and combine the ranges. pub fn add_u32(&mut self, number: u32) -> &mut Self { self.add_range(Ipv4AddressRange::new(number, number)); self } /// Retrieve a number at a given index /// /// NOTE: Unfortunately, we cannot conform to the Index trait /// as that would require us to return a reference to the number. /// See [this](https://stackoverflow.com/questions/39113649/implementing-index-trait-to-return-a-value-that-is-not-a-reference/79107690?noredirect=1#comment139489811_79107690) /// stackoverflow post for more information. #[allow(clippy::panic)] pub fn index(&self, index: u32) -> u32 { let mut mutable_index = index; for &range in &self.ranges { let range_length = range.end() - range.start() + 1; if mutable_index < range_length { return range.start() + mutable_index; } mutable_index -= range_length; } panic!("Index out of bounds"); } /// Sort and combine the ranges in the list /// /// Note that this can become an expensive operation if the list is large. /// Ranges should preferably be added in sorted order and all at once. fn sort_and_combine(&mut self) { if self.ranges.is_empty() { self.len = 0; return; } self.ranges.sort_by_key(|range| range.start()); let mut new_ranges: Vec = Vec::new(); for range in &self.ranges { if let Some(last) = new_ranges.last_mut() { if last.overlaps_or_adjacent(range) { last.combine(range); } else { new_ranges.push(*range); } } else { new_ranges.push(*range); } } self.ranges = new_ranges; // Calculate and store the length self.len = self .ranges .iter() .fold(0, |acc, range| acc + (range.end() - range.start() + 1)) } } impl Default for Ipv4AddressRanges { fn default() -> Self { Ipv4AddressRanges::new(Vec::new()) } } impl Iterator for Ipv4AddressRanges { type Item = u32; /// Iterate over the Feistel RNG. 
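    ///
    /// A sketch of the idea (illustrative, mirroring the body below): the current
    /// index is pushed through a few Feistel rounds, which yields a pseudo-random
    /// permutation of `0..len` without storing the shuffled order in memory:
    ///
    /// ```ignore
    /// let mut value = current_index as u128;
    /// for round in 0..rounds {
    ///     value = feistel_round(value, seed as u128, round as u128, len as u128);
    /// }
    /// let shuffled_index = (value % len as u128) as u32;
    /// ```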
fn next(&mut self) -> Option { if self.current_index == 0 { // Ensure we have a sorted and combined list of ranges before // we start iterating. This is for instance useful when we // have loaded Ipv4Addresses from JSON. self.sort_and_combine(); } if self.current_index >= self.len { return None; // All numbers have been generated } let mut value = self.current_index as u128; // Avoid divide by zero for feistel_round when self.len == 1 if self.len == 1 { value = 1; } else { for round in 0..self.rounds { value = feistel_round(value, self.seed.into(), round.into(), self.len.into()); } } self.current_index += 1; // Hack to handle the issue with odd self.len numbers. // // The algorithm has a bug where for odd numbers, the last // number it generate, is a duplicate. On top of that, the // final number from the range is always missing. We're // just hardcoded returning the last item from the range here // when that's the case. let computed_index = if self.current_index == self.len && !self.len.is_multiple_of(2) { self.len - 1 } else { (value % self.len as u128) as u32 }; Some(self.index(computed_index)) } } /// Implement display trait impl fmt::Display for Ipv4AddressRanges { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let ranges_str: Vec = self .ranges .iter() .map(|range| format!("{}", range)) .collect(); write!(f, "Ipv4Addresses [{}]", ranges_str.join(", ")) } } /// Cast a u32 into Ipv4Addresses impl From for Ipv4AddressRanges { fn from(value: u32) -> Self { Ipv4AddressRanges::new(vec![Ipv4AddressRange::new(value, value)]) } } /// Cast a str to Ipv4Addresses impl From<&str> for Ipv4AddressRanges { fn from(value: &str) -> Self { Ipv4AddressRanges::new(vec![Ipv4AddressRange::from(value)]) } } /// Cast Ipv4AddressRange iterator into Ipv4Addresses impl FromIterator for Ipv4AddressRanges { fn from_iter>(iter: I) -> Self { let ranges: Vec = iter.into_iter().collect(); Ipv4AddressRanges::new(ranges) } } /// Cast a vector of Ipv4AddressRange to Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(ranges: Vec) -> Self { Ipv4AddressRanges::new(ranges) } } // Cast a Range into Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(range: std::ops::Range) -> Self { Ipv4AddressRanges::new(vec![Ipv4AddressRange::new(range.start, range.end - 1)]) } } // Cast a vector of Range into Ipv4Addresses impl From>> for Ipv4AddressRanges { fn from(ranges: Vec>) -> Self { let int_ranges: Vec = ranges .into_iter() .map(|range| Ipv4AddressRange::new(range.start, range.end - 1)) .collect(); Ipv4AddressRanges::new(int_ranges) } } // Cast a vector of Ipv4Addr into Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(ip_addresses: Vec) -> Self { let int_ranges: Vec = ip_addresses .into_iter() .map(|ip_address| ip_address.into()) .collect(); Ipv4AddressRanges::new(int_ranges) } } // Cast a vector of Ipv4Net into Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(networks: Vec) -> Self { let int_ranges: Vec = networks.into_iter().map(|network| network.into()).collect(); Ipv4AddressRanges::new(int_ranges) } } // Cast a vector of IP address or network strings into Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(networks: Vec<&str>) -> Self { let int_ranges: Vec = networks.into_iter().map(|network| network.into()).collect(); Ipv4AddressRanges::new(int_ranges) } } // Cast a vector of String or network Strings into Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(networks: Vec) -> Self { let int_ranges: Vec = networks .into_iter() .map(|network| network.as_str().into()) .collect(); 
Ipv4AddressRanges::new(int_ranges) } } // Cast a HashSet of String or network Strings into Ipv4Addresses impl From> for Ipv4AddressRanges { fn from(networks: HashSet) -> Self { let int_ranges: Vec = networks .into_iter() .map(|network| network.as_str().into()) .collect(); Ipv4AddressRanges::new(int_ranges) } } /// Try to convert a serde_json::Value into Ipv4Addresses. /// /// # Examples of Supported Payloads /// /// ## Array of strings (CIDR or IP) /// ```json /// [ "192.168.1.0/24", "10.0.0.1" ] /// ``` /// Result: Ipv4Addresses with two ranges: one for the CIDR, one for the single IP. /// /// ## Array of numbers (u32 IPv4) /// ```json /// [ 3232235777, 167772161 ] /// ``` /// Result: Ipv4Addresses with two ranges: one for each IP address. /// ```json /// [ "192.168.1.0/24", "10.0.0.1" ] /// ``` /// Result: Ipv4Addresses with two ranges: one for the CIDR, one for the single IP. /// /// ## Array of numbers (u32 IPv4) /// ```json /// [ 3232235777, 167772161 ] /// ``` /// Result: Ipv4Addresses with two ranges: one for each IP address. /// /// ## Single string (CIDR or IP) /// ```json /// "192.168.1.0/24" /// ``` /// Result: Ipv4Addresses with one range for the CIDR. /// /// ## Single number (u32 IPv4) /// ```json /// 3232235777 /// ``` /// Result: Ipv4Addresses with one range for the IP address. /// /// Returns an error if the JSON is not a supported array, string, or number, or if any element fails to convert. impl TryFrom for Ipv4AddressRanges { type Error = String; fn try_from(value: serde_json::Value) -> Result { match value { Value::Array(arr) => { let mut ranges = Vec::with_capacity(arr.len()); for v in arr { let range = match v { Value::String(ref s) => Ipv4AddressRange::from(s.as_str()), Value::Number(ref n) if n.is_u64() => { let ip = n.as_u64().unwrap(); if ip <= u32::MAX as u64 { Ipv4AddressRange::from(ip as u32) } else { return Err(format!("Number out of range for IPv4: {}", ip)); } } _ => return Err(format!("Expected string or u32 in array, got {:?}", v)), }; ranges.push(range); } Ok(Ipv4AddressRanges::new(ranges)) } Value::String(s) => Ok(Ipv4AddressRanges::new(vec![Ipv4AddressRange::from( s.as_str(), )])), Value::Number(n) if n.is_u64() => { let ip = n.as_u64().unwrap(); if ip <= u32::MAX as u64 { Ok(Ipv4AddressRanges::new(vec![Ipv4AddressRange::from( ip as u32, )])) } else { Err(format!("Number out of range for IPv4: {}", ip)) } } other => Err(format!( "Expected JSON array, string, or number, got {:?}", other )), } } } impl<'de> Deserialize<'de> for Ipv4AddressRanges { /// Custom deserializer to allow for deserialization of Ipv4Addresses from various JSON formats. /// This leverages the TryFrom implementation for serde_json::Value. fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { // Deserialize into serde_json::Value first let value = serde_json::Value::deserialize(deserializer)?; Ipv4AddressRanges::try_from(value).map_err(de::Error::custom) } } impl Serialize for Ipv4AddressRanges { /// Serialize Ipv4Addresses as a Vec, where each range is represented as a string. /// /// Example output: /// `[ "192.168.1.0-192.168.1.255", "10.0.0.1" ]` fn serialize(&self, serializer: S) -> Result where S: Serializer, { let mut seq = serializer.serialize_seq(Some(self.ranges.len()))?; for range in &self.ranges { seq.serialize_element(&range)?; } seq.end() } } #[cfg(test)] mod tests { use super::*; // Test if the returned value of the index function is memory safe. 
// I initially used a hack to try and conform it to the Index trait // but the returned value would mutate when a new index was requested. #[test] fn test_memory_safety() { let range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(10, 20)], 42, 10); let i0: &u32 = &range_list.index(0); assert_eq!(*i0, 10); let _ = &range_list.index(5); assert_eq!(*i0, 10); } #[test] fn test_index_within_single_range() { let range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(10, 20)], 42, 10); assert_eq!(range_list.index(0), 10); assert_eq!(range_list.index(5), 15); assert_eq!(range_list.index(10), 20); } #[test] fn test_index_across_multiple_ranges() { let range_list = Ipv4AddressRanges::new_with_seed( vec![Ipv4AddressRange::new(10, 20), Ipv4AddressRange::new(30, 40)], 42, 10, ); assert_eq!(range_list.index(0), 10); assert_eq!(range_list.index(10), 20); assert_eq!(range_list.index(11), 30); assert_eq!(range_list.index(20), 39); assert_eq!(range_list.index(21), 40); } #[test] fn test_range_of_one() { // Ensure the divide by zero error won't return when providing a single range let mut range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(10, 10)], 42, 10); range_list.next(); } #[test] #[should_panic(expected = "Index out of bounds")] fn test_index_out_of_bounds() { let range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(10, 20)], 42, 10); let _ = range_list.index(11); // This should panic } #[test] fn test_intrange_list_iterator_single_range() { let mut range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(10, 15)], 42, 10); assert_eq!(range_list.next(), Some(10)); assert_eq!(range_list.next(), Some(11)); assert_eq!(range_list.next(), Some(12)); assert_eq!(range_list.next(), Some(13)); assert_eq!(range_list.next(), Some(14)); assert_eq!(range_list.next(), Some(15)); assert_eq!(range_list.next(), None); } #[test] fn test_intrange_list_iterator_multiple_ranges() { let range_list = Ipv4AddressRanges::new_with_seed( vec![Ipv4AddressRange::new(10, 12), Ipv4AddressRange::new(20, 22)], 42, 10, ); let mut iter = range_list.into_iter(); assert_eq!(iter.next(), Some(10)); assert_eq!(iter.next(), Some(11)); assert_eq!(iter.next(), Some(12)); assert_eq!(iter.next(), Some(20)); assert_eq!(iter.next(), Some(21)); assert_eq!(iter.next(), Some(22)); assert_eq!(iter.next(), None); } #[test] fn test_len_of_range_list() { let range_list = Ipv4AddressRanges::new_with_seed( vec![Ipv4AddressRange::new(10, 20), Ipv4AddressRange::new(30, 40)], 42, 10, ); assert_eq!(range_list.len, 22); } #[test] fn test_sort_and_combine() { let range_list = Ipv4AddressRanges::new_with_seed( vec![ Ipv4AddressRange::new(30, 40), Ipv4AddressRange::new(10, 20), Ipv4AddressRange::new(15, 25), ], 42, 10, ); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.len, 27); assert_eq!(range_list.ranges()[0].start(), 10); assert_eq!(range_list.ranges()[0].end(), 25); assert_eq!(range_list.ranges()[1].start(), 30); assert_eq!(range_list.ranges()[1].end(), 40); } #[test] fn test_add_range() { let mut range_list = Ipv4AddressRanges::new_with_seed( vec![Ipv4AddressRange::new(10, 20), Ipv4AddressRange::new(30, 40)], 42, 10, ); range_list.add_range(Ipv4AddressRange::new(15, 25)); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.len, 27); assert_eq!(range_list.ranges()[0].start(), 10); assert_eq!(range_list.ranges()[0].end(), 25); assert_eq!(range_list.ranges()[1].start(), 30); assert_eq!(range_list.ranges()[1].end(), 40); } #[test] fn 
test_add_number() { let mut range_list = Ipv4AddressRanges::new_with_seed( vec![Ipv4AddressRange::new(10, 20), Ipv4AddressRange::new(30, 40)], 42, 10, ); range_list.add_u32(21); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.len, 23); assert_eq!(range_list.ranges()[0].start(), 10); assert_eq!(range_list.ranges()[0].end(), 21); assert_eq!(range_list.ranges()[1].start(), 30); assert_eq!(range_list.ranges()[1].end(), 40); } #[test] fn test_iteration() { let mut range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(10, 20)], 42, 10); assert_eq!(range_list.next(), Some(12)); assert_eq!(range_list.next(), Some(13)); assert_eq!(range_list.next(), Some(14)); assert_eq!(range_list.next(), Some(10)); assert_eq!(range_list.next(), Some(11)); assert_eq!(range_list.next(), Some(17)); assert_eq!(range_list.next(), Some(18)); assert_eq!(range_list.next(), Some(19)); assert_eq!(range_list.next(), Some(15)); assert_eq!(range_list.next(), Some(16)); assert_eq!(range_list.next(), Some(20)); assert_eq!(range_list.next(), None); } #[test] fn test_add_method_chaining() { let mut range_list = Ipv4AddressRanges::new_with_seed(vec![Ipv4AddressRange::new(0, 0)], 42, 10); range_list.add_u32(22).add_u32(21); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.next(), Some(0)); assert_eq!(range_list.next(), Some(21)); assert_eq!(range_list.next(), Some(22)); } #[test] fn test_u16_intrange() { let range_list = Ipv4AddressRanges::new_with_seed( vec![Ipv4AddressRange::new(10, 20), Ipv4AddressRange::new(30, 40)], 42, 10, ); assert_eq!(range_list.len, 22); } } ``` ### `scanners/mrpf_engine/src/models/ipv4_range.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/models/ipv4_range.rs` ```rust use ipnet::Ipv4Net; use serde::de::{self, Deserialize, Deserializer}; use serde::ser::{Serialize, Serializer}; use serde_json::Value; use std::convert::TryFrom; use std::fmt; use std::net::Ipv4Addr; use std::ops::{Range, RangeInclusive}; use std::str::FromStr; /// Represent a single range of integer values #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct Ipv4AddressRange { start: u32, end: u32, current_index: usize, } impl Ipv4AddressRange { /// Start of the range pub fn start(&self) -> u32 { self.start } /// End of the range pub fn end(&self) -> u32 { self.end } /// Combine two ranges pub fn combine(&mut self, other: &Ipv4AddressRange) { self.start = self.start.min(other.start); self.end = self.end.max(other.end); } /// Returns the total length of the range pub fn len(&self) -> u32 { self.end - self.start + 1 } /// Check if the range is empty pub fn is_empty(&self) -> bool { self.start > self.end } /// Create a new integer range. /// /// This function will create a new integer range. If the end value is less than the start value, /// the values will be swapped. pub fn new(start: u32, end: u32) -> Self { Ipv4AddressRange { start: start.min(end), end: start.max(end), current_index: 0, } } pub fn overlaps_or_adjacent(&self, other: &Ipv4AddressRange) -> bool { self.start <= other.end + 1 && self.end + 1 >= other.start } } impl Iterator for Ipv4AddressRange { type Item = u32; fn next(&mut self) -> Option { if self.current_index as u32 <= self.end - self.start { let result = self.start + self.current_index as u32; self.current_index += 1; return Some(result); } None } } /// Display impl fmt::Display for Ipv4AddressRange { /// Display the range in a human-readable format. 
/// /// For example, "192.168.0.12-192.168.10.123" fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let start_ip = Ipv4Addr::from(self.start); let end_ip = Ipv4Addr::from(self.end); if self.start == self.end { write!(f, "{}", start_ip) } else { write!(f, "{}-{}", start_ip, end_ip) } } } /// Cast a single Ipv4Addr into an Ipv4AddressRange /// /// The start and end value will be the same impl From for Ipv4AddressRange { fn from(ip: Ipv4Addr) -> Self { let ip_as_u32 = u32::from(ip); Ipv4AddressRange::new(ip_as_u32, ip_as_u32) } } /// Cast a single Ipv4Net into an Ipv4AddressRange impl From for Ipv4AddressRange { fn from(network: Ipv4Net) -> Self { let start = network.addr().to_bits(); let mut end = network.broadcast().to_bits(); if end != start { end -= 1; } Ipv4AddressRange::new(start, end) } } /// Cast a single integer as an intrange /// /// The start and end value will be the same impl From for Ipv4AddressRange { fn from(i: u32) -> Self { Ipv4AddressRange::new(i, i) } } /// Cast Range to Ipv4AddressRange impl From> for Ipv4AddressRange { fn from(range: Range) -> Self { Ipv4AddressRange::new(range.start, range.end) } } /// Cast RangeInclusive to Ipv4AddressRange impl From> for Ipv4AddressRange { fn from(range: RangeInclusive) -> Self { let (start, end) = range.into_inner(); Ipv4AddressRange::new(start, end) } } /// Cast Str to Ipv4AddressRange /// /// It supports the following formats: /// - Single IPv4 address: "192.168.1.1" /// - CIDR notation: "192.168.1.0/24" /// - Range notation: "192.168.1.1-192.168.1.10" impl From<&str> for Ipv4AddressRange { fn from(value: &str) -> Self { if let Some((start, end)) = value.split_once('-') { // Range notation: "192.168.1.1-192.168.1.10" let start_ip = Ipv4Addr::from_str(start.trim()) .unwrap_or_else(|_| panic!("Invalid IPv4 address in range: {start}")); let end_ip = Ipv4Addr::from_str(end.trim()) .unwrap_or_else(|_| panic!("Invalid IPv4 address in range: {end}")); Ipv4AddressRange::new(u32::from(start_ip), u32::from(end_ip)) } else if value.contains('/') { Ipv4Net::from_str(value) .unwrap_or_else(|_| panic!("Invalid CIDR notation {value}")) .into() } else { Ipv4Addr::from_str(value) .unwrap_or_else(|_| panic!("Invalid IPv4 address {value}")) .into() } } } /// Implement TryFrom for serde_json::Value to allow for deserialization /// /// This implementation supports deserialization from string representations of either /// a single IPv4 address (e.g., "192.168.1.1") or a CIDR notation (e.g., "192.168.1.0/24"). /// /// It also supports deserialization from a number representing a single IPv4 address /// in its u32 integer form (e.g., 3232235777). impl TryFrom for Ipv4AddressRange { type Error = &'static str; fn try_from(value: serde_json::Value) -> Result { match value { Value::String(s) => Ok(Ipv4AddressRange::from(s.as_str())), Value::Number(num) => { if let Some(n) = num.as_u64() { if n <= u32::MAX as u64 { Ok(Ipv4AddressRange::from(n as u32)) } else { Err("Number out of range for u32") } } else { Err("Invalid number format") } } _ => Err("Only string serialization is supported for Ipv4AddressRange"), } } } impl<'de> Deserialize<'de> for Ipv4AddressRange { /// Custom deserializer to allow for deserialization of Ipv4AddressRange from various JSON formats. /// This leverages the TryFrom implementation for serde_json::Value. 
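    ///
    /// A deserialization sketch (assuming `serde_json`; the addresses are illustrative):
    ///
    /// ```ignore
    /// // Accepts "a.b.c.d", "a.b.c.d/nn", "a.b.c.d-e.f.g.h", or a u32 value.
    /// let range: Ipv4AddressRange = serde_json::from_str("\"192.168.1.1-192.168.1.10\"").unwrap();
    /// let single: Ipv4AddressRange = serde_json::from_value(serde_json::json!(3232235777u32)).unwrap();
    /// ```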
fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { // Deserialize into serde_json::Value first let value = serde_json::Value::deserialize(deserializer)?; Ipv4AddressRange::try_from(value).map_err(de::Error::custom) } } impl Serialize for Ipv4AddressRange { /// Custom serializer to convert Ipv4AddressRange into a string representation. /// /// for example "192.168.1.1-192.168.1.10" or "192.168.1.1" for single IPs. fn serialize(&self, serializer: S) -> Result where S: Serializer, { let start_ip = Ipv4Addr::from(self.start); let end_ip = Ipv4Addr::from(self.end); let s = if self.start == self.end { // Single IP start_ip.to_string() } else { // Range format!("{}-{}", start_ip, end_ip) }; serializer.serialize_str(&s) } } #[cfg(test)] mod tests { use super::*; // Test the Ipv4AddressRange struct #[test] fn test_start_and_end_reversed() { let range = Ipv4AddressRange::new(20, 10); assert_eq!(range.start, 10); assert_eq!(range.end, 20); } #[test] fn test_len_of_range() { let range = Ipv4AddressRange::new(10, 20); assert_eq!(range.len(), 11); } #[test] fn test_intrange_iterator() { let range = Ipv4AddressRange::new(10, 15); let mut iter = range.into_iter(); assert_eq!(iter.next(), Some(10)); assert_eq!(iter.next(), Some(11)); assert_eq!(iter.next(), Some(12)); assert_eq!(iter.next(), Some(13)); assert_eq!(iter.next(), Some(14)); assert_eq!(iter.next(), Some(15)); assert_eq!(iter.next(), None); } #[test] fn test_overlaps_or_adjacent() { let range1 = Ipv4AddressRange::new(10, 20); let range2 = Ipv4AddressRange::new(15, 25); let range3 = Ipv4AddressRange::new(21, 21); let range4 = Ipv4AddressRange::new(30, 40); assert!(range1.overlaps_or_adjacent(&range2)); assert!(range2.overlaps_or_adjacent(&range1)); assert!(range1.overlaps_or_adjacent(&range3)); assert!(!range1.overlaps_or_adjacent(&range4)); assert!(!range4.overlaps_or_adjacent(&range2)); } } ``` ### `scanners/mrpf_engine/src/models/mod.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/models/mod.rs` ```rust mod feistel; pub(crate) use feistel::feistel_round; mod ipv4_address_ranges; pub use ipv4_range::Ipv4AddressRange; mod ipv4_range; pub use ipv4_address_ranges::Ipv4AddressRanges; mod port_ranges; pub use port_ranges::PortRanges; mod port_range; pub use port_range::PortRange; ``` ### `scanners/mrpf_engine/src/models/port_range.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/models/port_range.rs` ```rust use serde::de::{self, Deserialize, Deserializer}; use serde::ser::{Serialize, Serializer}; use serde_json::Value; use std::convert::TryFrom; use std::fmt; use std::ops::{Range, RangeInclusive}; use std::str::FromStr; /// Represent a single range of integer values #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct PortRange { start: u16, end: u16, current_index: usize, } impl PortRange { /// Start of the range pub fn start(&self) -> u16 { self.start } /// End of the range pub fn end(&self) -> u16 { self.end } /// Combine two ranges pub fn combine(&mut self, other: &PortRange) { self.start = self.start.min(other.start); self.end = self.end.max(other.end); } /// Returns the total length of the range pub fn len(&self) -> u16 { self.end - self.start + 1 } /// Check if the range is empty pub fn is_empty(&self) -> bool { self.start > self.end } /// Create a new integer range. /// /// This function will create a new integer range. If the end value is less than the start value, /// the values will be swapped. 
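    ///
    /// # Example
    ///
    /// A small sketch of the swap behaviour:
    ///
    /// ```ignore
    /// let range = PortRange::new(1024, 1);
    /// assert_eq!(range.start(), 1);
    /// assert_eq!(range.end(), 1024);
    /// ```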
pub fn new(start: u16, end: u16) -> Self { PortRange { start: start.min(end), end: start.max(end), current_index: 0, } } pub fn overlaps_or_adjacent(&self, other: &PortRange) -> bool { self.start <= other.end + 1 && self.end + 1 >= other.start } } impl Iterator for PortRange { type Item = u16; fn next(&mut self) -> Option { if let Ok(current_index_as_t) = u16::try_from(self.current_index) && current_index_as_t <= self.end - self.start { let result = self.start + current_index_as_t; self.current_index += 1; return Some(result); } None } } /// Display impl fmt::Display for PortRange { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.start == self.end { write!(f, "{}", self.start) } else { write!(f, "{}..{}", self.start, self.end) } } } /// Cast a single integer as an PortRange /// /// The start and end value will be the same impl From for PortRange { fn from(i: u16) -> Self { PortRange::new(i, i) } } /// Cast Range to PortRange impl From> for PortRange { fn from(range: Range) -> Self { PortRange::new(range.start, range.end) } } /// Cast RangeInclusive to Ipv4AddressRange impl From> for PortRange { fn from(range: RangeInclusive) -> Self { let (start, end) = range.into_inner(); PortRange::new(start, end) } } // Cast Str as singe port or range of ports. // Either "443" or "1-1024" syntax is supported impl From<&str> for PortRange { fn from(value: &str) -> Self { if let Some(idx) = value.find('-') { let start_str = &value[..idx].trim(); let end_str = &value[idx + 1..].trim(); let start = u16::from_str(start_str).expect("Invalid start port number"); let end = u16::from_str(end_str).expect("Invalid end port number"); PortRange::new(start, end) } else { let port = u16::from_str(value.trim()).expect("Invalid port number"); PortRange::new(port, port) } } } /// Try to construct a PortRange from a serde_json::Value object. /// # Examples of Supported Payloads /// /// ## Object with start and end /// ```json /// { "start": 80, "end": 443 } /// ``` /// Result: PortRange { start: 80, end: 443, current_index: 0 } /// /// ## Single number /// ```json /// 22 /// ``` /// Result: PortRange { start: 22, end: 22, current_index: 0 } /// /// ## String single port /// ```json /// "8080" /// ``` /// Result: PortRange { start: 8080, end: 8080, current_index: 0 } /// /// ## String range /// ```json /// "1000-2000" /// ``` /// Result: PortRange { start: 1000, end: 2000, current_index: 0 } /// /// Returns an error if the fields are missing or not u16. /// ``` impl TryFrom for PortRange { type Error = &'static str; fn try_from(value: serde_json::Value) -> Result { match value { serde_json::Value::Object(map) => { let start = map .get("start") .and_then(|v| v.as_u64()) .ok_or("Missing or invalid 'start' field")?; let end = map .get("end") .and_then(|v| v.as_u64()) .ok_or("Missing or invalid 'end' field")?; if start > u16::MAX as u64 || end > u16::MAX as u64 { return Err("'start' or 'end' out of u16 range"); } Ok(PortRange::new(start as u16, end as u16)) } serde_json::Value::Number(num) => { let port = num.as_u64().ok_or("Invalid number for port")?; if port > u16::MAX as u64 { return Err("Port out of u16 range"); } Ok(PortRange::new(port as u16, port as u16)) } serde_json::Value::String(s) => { // Use From<&str> for string parsing Ok(PortRange::from(s.as_str())) } _ => Err("Unsupported JSON type for PortRange"), } } } impl<'de> Deserialize<'de> for PortRange { /// Custom deserializer to allow for deserialization of Ports from various JSON formats. 
/// This leverages the TryFrom implementation for serde_json::Value. fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { // Deserialize into serde_json::Value first let value = Value::deserialize(deserializer)?; PortRange::try_from(value).map_err(de::Error::custom) } } impl Serialize for PortRange { /// Custom serializer to serialize PortRange as a json object start and end fields. /// /// For instance { "start": 80, "end": 443 } fn serialize(&self, serializer: S) -> Result where S: Serializer, { use serde::ser::SerializeStruct; let mut state = serializer.serialize_struct("PortRange", 2)?; state.serialize_field("start", &self.start)?; state.serialize_field("end", &self.end)?; state.end() } } #[cfg(test)] mod tests { use super::*; // Test cast PortRange from string to PortRange #[test] fn test_cast_str_to_portrange_u16_single_port() { let range: PortRange = "443".into(); assert_eq!(range.start, 443); assert_eq!(range.end, 443); } #[test] fn test_cast_str_to_portrange_u16_port_range() { let range: PortRange = "1-1024".into(); assert_eq!(range.start, 1); assert_eq!(range.end, 1024); } #[test] #[should_panic] fn test_cast_str_to_portrange_u16_invalid_port() { let _: PortRange = "invalid".into(); } #[test] fn test_cast_str_to_portrange_u16_reversed_range() { let range: PortRange = "1024-1".into(); assert_eq!(range.start, 1); assert_eq!(range.end, 1024); } // Test the PortRange struct with 32 #[test] fn test_start_and_end_reversed() { let range = PortRange::new(20, 10); assert_eq!(range.start, 10); assert_eq!(range.end, 20); } #[test] fn test_len_of_range() { let range = PortRange::new(10, 20); assert_eq!(range.len(), 11); } #[test] fn test_portrange_iterator() { let range = PortRange::new(10, 15); let mut iter = range.into_iter(); assert_eq!(iter.next(), Some(10)); assert_eq!(iter.next(), Some(11)); assert_eq!(iter.next(), Some(12)); assert_eq!(iter.next(), Some(13)); assert_eq!(iter.next(), Some(14)); assert_eq!(iter.next(), Some(15)); assert_eq!(iter.next(), None); } #[test] fn test_overlaps_or_adjacent() { let range1 = PortRange::new(10, 20); let range2 = PortRange::new(15, 25); let range3 = PortRange::new(21, 21); let range4 = PortRange::new(30, 40); assert!(range1.overlaps_or_adjacent(&range2)); assert!(range2.overlaps_or_adjacent(&range1)); assert!(range1.overlaps_or_adjacent(&range3)); assert!(!range1.overlaps_or_adjacent(&range4)); assert!(!range4.overlaps_or_adjacent(&range2)); } // Test the PortRange struct with u16 #[test] fn test_start_and_end_reversed_u16() { let range = PortRange::new(20u16, 10u16); assert_eq!(range.start, 10); assert_eq!(range.end, 20); } } ``` ### `scanners/mrpf_engine/src/models/port_ranges.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/models/port_ranges.rs` ```rust use super::PortRange; use super::feistel_round; use serde::de::{self, Deserialize, Deserializer}; use serde::ser::{Serialize, SerializeSeq, Serializer}; use serde_json::Value; use std::collections::HashSet; use std::fmt; /// Represent a list of ranges of integer values /// /// This struct will allow us to store a list of ranges of integer values. #[derive(Debug, Clone)] pub struct PortRanges { ranges: Vec, /// The current index of iteration current_index: u16, /// A seed for randomization of the index /// /// Set this to the same value to get a deterministic output. /// This is useful for unit testing or resuming a previous scan seed: u16, /// The amount of rounds te randomizer runs through the /// feistel algorithm. 
/// /// Higher values will increase the /// amount of randomization. rounds: u16, /// Total length of the ports in the combined ranges /// This can potentially be all available ports 65535 len: u16, } impl PortRanges { /// Get the ranges stored in the Ports pub fn ranges(&self) -> &Vec { &self.ranges } // Return the total length of ports stored inside the PortRanges combined pub fn len(&self) -> u16 { self.len } /// Check if the Ports is empty pub fn is_empty(&self) -> bool { self.len == 0 } /// Create a new Ports /// /// This function will create a new Ports. The ranges will be sorted and combined. pub fn new(ranges: Vec) -> Self { let mut int_range_list = PortRanges { ranges, current_index: 0, seed: 42, // TODO: randomize the seed rounds: 10, len: 0, }; int_range_list.sort_and_combine(); int_range_list } /// Create a new Ports with a user provided seed. /// /// This function will create a new Ports. The ranges will be sorted and combined. pub fn new_with_seed(ranges: Vec, seed: u16, rounds: u16) -> Self { let mut int_range_list = PortRanges { ranges, current_index: 0, seed, rounds, len: 0, }; int_range_list.sort_and_combine(); int_range_list } /// Add a range to the list /// /// NOTE: This can be an expensive operation if the list is large, since /// we will need to sort and combine the ranges. pub fn add_range(&mut self, range: PortRange) -> &mut Self { self.ranges.push(range); self.sort_and_combine(); self } /// Add multiple ranges to the list /// /// NOTE: This can be an expensive operation if the list is large, since /// we will need to sort and combine the ranges. pub fn add_ranges(&mut self, ranges: PortRanges) -> &mut Self { self.ranges.extend(ranges.ranges); self.sort_and_combine(); self } /// Add a single number to the list /// /// NOTE: This can be an expensive operation if the list is large, since /// we will need to sort and combine the ranges. pub fn add_port(&mut self, port: u16) -> &mut Self { self.add_range(PortRange::new(port, port)); self } /// Retrieve a number at a given index /// /// NOTE: Unfortunately, we cannot conform to the Index trait /// as that would require us to return a reference to the number. /// See [this](https://stackoverflow.com/questions/39113649/implementing-index-trait-to-return-a-value-that-is-not-a-reference/79107690?noredirect=1#comment139489811_79107690) /// stackoverflow post for more information. pub fn index(&self, index: u16) -> u16 { let mut mutable_index = index; for &range in &self.ranges { let range_length = range.end() - range.start() + 1; if mutable_index < range_length { return range.start() + mutable_index; } mutable_index -= range_length; } panic!("Index out of bounds"); } /// Sort and combine the ranges in the list /// /// Note that this can become an expensive operation if the list is large. /// Ranges should preferably be added in sorted order and all at once. 
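    ///
    /// A sketch of the effect (mirrors the unit tests below): overlapping and
    /// adjacent ranges are merged and `len` is recomputed:
    ///
    /// ```ignore
    /// let ports = PortRanges::new(vec![
    ///     PortRange::new(30, 40),
    ///     PortRange::new(10, 20),
    ///     PortRange::new(15, 25),
    /// ]);
    /// // Combined into [10-25], [30-40]
    /// assert_eq!(ports.ranges().len(), 2);
    /// assert_eq!(ports.len(), 27);
    /// ```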
fn sort_and_combine(&mut self) { if self.ranges.is_empty() { self.len = 0; return; } self.ranges.sort_by_key(|range| range.start()); let mut new_ranges: Vec = Vec::new(); for range in &self.ranges { if let Some(last) = new_ranges.last_mut() { if last.overlaps_or_adjacent(range) { last.combine(range); } else { new_ranges.push(*range); } } else { new_ranges.push(*range); } } self.ranges = new_ranges; // Calculate and store the length self.len = self .ranges .iter() .fold(0, |acc, range| acc + (range.end() - range.start() + 1)) } } impl Default for PortRanges { fn default() -> Self { PortRanges::new(Vec::new()) } } impl Iterator for PortRanges { type Item = u16; /// Iterate over the Feistel RNG. fn next(&mut self) -> Option { if self.current_index >= self.len { return None; // All numbers have been generated } let mut value = self.current_index as u128; // Avoid divide by zero for feistel_round when self.len == 1 if self.len == 1 { value = 1; } else { for round in 0..self.rounds { value = feistel_round(value, self.seed.into(), round.into(), self.len.into()); } } self.current_index += 1; // Hack to handle the issue with odd self.len numbers. // // The algorithm has a bug where for odd numbers, the last // number it generate, is a duplicate. On top of that, the // final number from the range is always missing. We're // just hardcoded returning the last item from the range here // if thats the case. let computed_index = if self.current_index == self.len && !self.len.is_multiple_of(2) { self.len - 1 } else { (value % self.len as u128) as u16 }; Some(self.index(computed_index)) } } /// Implement display trait impl fmt::Display for PortRanges { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let ranges_str: Vec = self .ranges .iter() .map(|range| format!("{}", range)) .collect(); write!(f, "Ports [{}]", ranges_str.join(", ")) } } // Cast a into Ports impl From for PortRanges { fn from(value: u16) -> Self { PortRanges::new(vec![PortRange::new(value, value)]) } } /// Cast a into Ports /// /// This is useful when deserializing from formats like JSON, where numeric values /// are often represented as u64 (e.g., serde_json::Value::as_u64). Since there is /// no as_u16 method in serde_json, this allows direct conversion from u64 to Ports, /// panicking if the value does not fit in a u16. 
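///
/// A usage sketch:
///
/// ```ignore
/// let ports: PortRanges = 443u64.into();
/// assert_eq!(ports.len(), 1);
/// // Values above u16::MAX would panic:
/// // let bad: PortRanges = 70_000u64.into();
/// ```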
impl From for PortRanges { fn from(value: u64) -> Self { let port = u16::try_from(value).expect("Port value out of range for u16"); PortRanges::new(vec![PortRange::new(port, port)]) } } /// Cast PortRange iterator into Ports impl FromIterator for PortRanges { fn from_iter>(iter: I) -> Self { let ranges: Vec = iter.into_iter().collect(); PortRanges::new(ranges) } } // Cast a &str into Ports for ports // Mix and match ports like this: // 1-1024,443,1,8080-8081 impl From<&str> for PortRanges { fn from(s: &str) -> Self { let int_ranges: Vec = s.split(',').map(|range_str| range_str.into()).collect(); PortRanges::new(int_ranges) } } // Cast a String into Ports for ports // Mix and match ports like this: // 1-1024,443,1,8080-8081 impl From for PortRanges { fn from(s: String) -> Self { let int_ranges: Vec = s.split(',').map(|range_str| range_str.into()).collect(); PortRanges::new(int_ranges) } } /// Cast a vector of PortRange to Ports impl From> for PortRanges { fn from(ranges: Vec) -> Self { PortRanges::new(ranges) } } /// Cast a vector of String to Ports impl From> for PortRanges { fn from(ranges: Vec) -> Self { let int_ranges: Vec = ranges.iter().map(|s| s.as_str().into()).collect(); PortRanges::new(int_ranges) } } /// Cast a HashSet of String to Ports impl From> for PortRanges { fn from(ranges: HashSet) -> Self { let int_ranges: Vec = ranges.iter().map(|s| s.as_str().into()).collect(); PortRanges::new(int_ranges) } } /// Cast a vector of &str to Ports impl From> for PortRanges { fn from(ranges: Vec<&str>) -> Self { let int_ranges: Vec = ranges.iter().map(|s| (*s).into()).collect(); PortRanges::new(int_ranges) } } /// Cast a HashSet of &str to Ports impl From> for PortRanges { fn from(ranges: HashSet<&str>) -> Self { let int_ranges: Vec = ranges.iter().map(|s| (*s).into()).collect(); PortRanges::new(int_ranges) } } // Cast a Range into Ports impl From> for PortRanges { fn from(range: std::ops::Range) -> Self { PortRanges::new(vec![PortRange::new(range.start, range.end - 1)]) } } /// Cast a vector of Range into Ports /// /// This is useful when deserializing from formats like JSON, where numeric values /// are often represented as u64 (e.g., serde_json::Value::as_u64). Since there is /// no as_u16 method in serde_json, this allows direct conversion from u64 ranges to Ports, /// panicking if any value does not fit in a u16. impl From>> for PortRanges { fn from(ranges: Vec>) -> Self { let int_ranges: Vec = ranges .into_iter() .map(|range| { let start = u16::try_from(range.start).expect("Port range start out of u16 bounds"); let end = u16::try_from(range.end - 1).expect("Port range end out of u16 bounds"); PortRange::new(start, end) }) .collect(); PortRanges::new(int_ranges) } } /// Try to construct a `Ports` from a `serde_json::Value`. /// /// Supports multiple JSON payload formats: /// /// 1. Array of objects with "start" and "end" fields: /// ```json /// [ /// {"start": 1, "end": 123}, /// {"start": 443, "end": 443} /// ] /// ``` /// /// 2. Array of strings (single ports or ranges): /// ```json /// [ "1-123", "443", "8080-8081" ] /// ``` /// /// 3. Array of integers (single ports): /// ```json /// [1, 443, 8080] /// ``` /// /// Each object is converted to a `PortRange` using its `TryFrom` implementation. /// Returns an error if the JSON is not a supported array, or if any element fails to convert. 
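///
/// A conversion sketch (assuming `serde_json` is available):
///
/// ```ignore
/// let value = serde_json::json!(["1-123", "443", 8080]);
/// let ports = PortRanges::try_from(value).expect("valid ports payload");
/// assert_eq!(ports.ranges().len(), 3);
/// ```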
impl std::convert::TryFrom for PortRanges { type Error = String; fn try_from(value: serde_json::Value) -> Result { match value { Value::Array(arr) => { let mut ranges = Vec::with_capacity(arr.len()); for item in arr { // Try PortRange::try_from first, fallback to From<&str> if it's a string let pr = match &item { Value::Object(_) | Value::Number(_) => PortRange::try_from(item) .map_err(|e| format!("PortRange error: {e}"))?, Value::String(s) => PortRange::from(s.as_str()), _ => return Err(format!("Unsupported item type in ports array: {item}")), }; ranges.push(pr); } Ok(PortRanges::new(ranges)) } other => Err(format!("Expected an array of ports but got {other}")), } } } impl<'de> Deserialize<'de> for PortRanges { /// Custom deserializer to allow for deserialization of Ports from various JSON formats. /// This leverages the TryFrom implementation for serde_json::Value. fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { // Deserialize into serde_json::Value first let value = Value::deserialize(deserializer)?; PortRanges::try_from(value).map_err(de::Error::custom) } } impl Serialize for PortRanges { /// Serialize Ports as a Vec>, where each range is represented as a string. /// /// Example output: /// `[ {"start": 1, "end": 10}, {"start": 12, "end": 30} ]` fn serialize(&self, serializer: S) -> Result where S: Serializer, { let mut seq = serializer.serialize_seq(Some(self.ranges.len()))?; for range in &self.ranges { seq.serialize_element(&range)?; } seq.end() } } #[cfg(test)] mod tests { use super::*; // Test if the returned value of the index function is memory safe. // I initially used a hack to try and conform it to the Index trait // but the returned value would mutate when a new index was requested. #[test] fn test_memory_safety() { let range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20)], 42, 10); let i0: &u16 = &range_list.index(0); assert_eq!(*i0, 10); let _ = &range_list.index(5); assert_eq!(*i0, 10); } #[test] fn test_casting_port_ranges_from_string() { let range_list: PortRanges = "1-1024,1443,1,8080-8081".into(); assert_eq!(range_list.ranges().len(), 3); assert_eq!(range_list.len(), 1027); assert_eq!(range_list.ranges()[0].start(), 1); assert_eq!(range_list.ranges()[0].end(), 1024); assert_eq!(range_list.ranges()[1].start(), 1443); assert_eq!(range_list.ranges()[1].end(), 1443); assert_eq!(range_list.ranges()[2].start(), 8080); assert_eq!(range_list.ranges()[2].end(), 8081); } #[test] fn test_casting_single_port_from_string() { let range_list: PortRanges = "443".into(); assert_eq!(range_list.ranges().len(), 1); assert_eq!(range_list.len(), 1); assert_eq!(range_list.ranges()[0].start(), 443); assert_eq!(range_list.ranges()[0].end(), 443); } #[test] fn test_index_within_single_range() { let range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20)], 42, 10); assert_eq!(range_list.index(0), 10); assert_eq!(range_list.index(5), 15); assert_eq!(range_list.index(10), 20); } #[test] fn test_index_across_multiple_ranges() { let range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20), PortRange::new(30, 40)], 42, 10); assert_eq!(range_list.index(0), 10); assert_eq!(range_list.index(10), 20); assert_eq!(range_list.index(11), 30); assert_eq!(range_list.index(20), 39); assert_eq!(range_list.index(21), 40); } #[test] fn test_range_of_one() { // Ensure the divide by zero error won't return when providing a single range let mut range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 10)], 42, 10); range_list.next(); } #[test] 
#[should_panic(expected = "Index out of bounds")] fn test_index_out_of_bounds() { let range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20)], 42, 10); let _ = range_list.index(11); // This should panic } #[test] fn test_intrange_list_iterator_single_range() { let mut range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 15)], 42, 10); assert_eq!(range_list.next(), Some(10)); assert_eq!(range_list.next(), Some(11)); assert_eq!(range_list.next(), Some(12)); assert_eq!(range_list.next(), Some(13)); assert_eq!(range_list.next(), Some(14)); assert_eq!(range_list.next(), Some(15)); assert_eq!(range_list.next(), None); } #[test] fn test_intrange_list_iterator_multiple_ranges() { let range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 12), PortRange::new(20, 22)], 42, 10); let mut iter = range_list.into_iter(); assert_eq!(iter.next(), Some(10)); assert_eq!(iter.next(), Some(11)); assert_eq!(iter.next(), Some(12)); assert_eq!(iter.next(), Some(20)); assert_eq!(iter.next(), Some(21)); assert_eq!(iter.next(), Some(22)); assert_eq!(iter.next(), None); } #[test] fn test_len_of_range_list() { let range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20), PortRange::new(30, 40)], 42, 10); assert_eq!(range_list.len, 22); } #[test] fn test_sort_and_combine() { let range_list = PortRanges::new_with_seed( vec![ PortRange::new(30, 40), PortRange::new(10, 20), PortRange::new(15, 25), ], 42, 10, ); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.len, 27); assert_eq!(range_list.ranges()[0].start(), 10); assert_eq!(range_list.ranges()[0].end(), 25); assert_eq!(range_list.ranges()[1].start(), 30); assert_eq!(range_list.ranges()[1].end(), 40); } #[test] fn test_add_range() { let mut range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20), PortRange::new(30, 40)], 42, 10); range_list.add_range(PortRange::new(15, 25)); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.len, 27); assert_eq!(range_list.ranges()[0].start(), 10); assert_eq!(range_list.ranges()[0].end(), 25); assert_eq!(range_list.ranges()[1].start(), 30); assert_eq!(range_list.ranges()[1].end(), 40); } #[test] fn test_add_number() { let mut range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20), PortRange::new(30, 40)], 42, 10); range_list.add_port(21); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.len, 23); assert_eq!(range_list.ranges()[0].start(), 10); assert_eq!(range_list.ranges()[0].end(), 21); assert_eq!(range_list.ranges()[1].start(), 30); assert_eq!(range_list.ranges()[1].end(), 40); } #[test] fn test_iteration() { let mut range_list = PortRanges::new_with_seed(vec![PortRange::new(10, 20)], 42, 10); assert_eq!(range_list.next(), Some(12)); assert_eq!(range_list.next(), Some(13)); assert_eq!(range_list.next(), Some(14)); assert_eq!(range_list.next(), Some(10)); assert_eq!(range_list.next(), Some(11)); assert_eq!(range_list.next(), Some(17)); assert_eq!(range_list.next(), Some(18)); assert_eq!(range_list.next(), Some(19)); assert_eq!(range_list.next(), Some(15)); assert_eq!(range_list.next(), Some(16)); assert_eq!(range_list.next(), Some(20)); assert_eq!(range_list.next(), None); } #[test] fn test_add_method_chaining() { let mut range_list = PortRanges::new_with_seed(vec![PortRange::new(0, 0)], 42, 10); range_list.add_port(22).add_port(21); assert_eq!(range_list.ranges().len(), 2); assert_eq!(range_list.next(), Some(0)); assert_eq!(range_list.next(), Some(21)); assert_eq!(range_list.next(), Some(22)); } #[test] fn test_u16_intrange() { 
let range_list = PortRanges::new_with_seed( vec![PortRange::new(10u16, 20u16), PortRange::new(30, 40)], 42, 10, ); assert_eq!(range_list.len, 22); } } ``` ### `scanners/mrpf_engine/src/network/arp.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/arp.rs` ```rust use crate::{ScannerConfig, unwrap_or_continue}; use pnet::datalink::{DataLinkReceiver, DataLinkSender}; use pnet::packet::arp::{ArpHardwareTypes, ArpOperations, ArpPacket, MutableArpPacket}; use pnet::packet::ethernet::{EtherTypes, EthernetPacket, MutableEthernetPacket}; use pnet::packet::{MutablePacket, Packet}; use pnet::util::MacAddr; use std::net::Ipv4Addr; use std::sync::mpsc::Sender; use std::time::{Duration, Instant}; const ARP_REQUEST_SIZE: usize = 60; /// Retrieve router arp address /// /// Note that when the router receives our ARP request, it will also automatically /// store our MAC address in its ARP table. This is useful for ARP/IP spoofing. /// /// # Arguments /// * `src_mac` - The source MAC address /// * `src_ip` - The source IP address /// * `router_ip` - The router IP address /// * `tx` - The data link sender /// * `rx` - The data link receiver /// /// # Returns /// The router MAC address /// /// # Panics /// If the ARP request fails to send or the ARP reply fails to receive /// /// # Example /// /// ```ignore /// use pnet::datalink::{self, NetworkInterface}; /// use pnet::datalink::Channel::Ethernet; /// use pnet::packet::ethernet::{EtherTypes, EthernetPacket}; /// use pnet::packet::arp::{ArpHardwareType, ArpOperations, ArpPacket, MutableArpPacket}; /// use pnet::packet::{MutablePacket, Packet}; /// use pnet::util::MacAddr; /// use std::net::Ipv4Addr; /// use std::sync::mpsc::{self, Receiver, Sender}; /// use std::thread; /// use std::time::Duration; /// /// let interfaces = datalink::interfaces(); /// let interface = interfaces /// .iter() /// .find(|iface| iface.is_up() && !iface.is_loopback() && !iface.ips.is_empty()) /// .expect("Failed to find a network interface"); /// /// let (mut tx, mut rx) = match datalink::channel(&interface, Default::default()) { /// /// Ok(Ethernet(tx, rx)) => (tx, rx), /// Ok(_) => panic!("Unhandled channel type"), /// Err(e) => panic!("Failed to create datalink channel: {}", e), /// }; /// /// let router_ip = Ipv4Addr::new(192, 168, 1, 1); /// let src_mac = interface.mac.unwrap(); /// let src_ip = interface.ips[0].ip(); /// /// let router_mac = get_router_arp_address(src_mac, src_ip, router_ip, &mut tx, &mut rx); /// println!("Router MAC address: {:?}", router_mac); /// ``` pub fn get_router_arp_address( src_mac: MacAddr, src_ip: Ipv4Addr, router_ip: Ipv4Addr, tx: &mut dyn DataLinkSender, rx: &mut dyn DataLinkReceiver, ) -> Option { // Send gratuitous ARP request to announce our spoofed IP address and Mac Address. // It seems it's not always updating the ARP table of the router, so we need // to ensure to do an ARP reply as well. // // UPDATE: This doesn't seem to be really needed, as the router will automatically // update its ARP table when it receives our ARP request? 
// // let gratuitous_arp_request = create_arp_request(src_mac, src_ip, src_ip)?; // tx.send_to(&gratuitous_arp_request, None) // .unwrap() // .expect("Failed to send gratuitous ARP request for router address"); // Send Arp request to router and wait for reply let arp_request_interval = Duration::from_secs(2); let mut last_request_time = Instant::now() - arp_request_interval; loop { // Send request for routers MAC address every 2 seconds if last_request_time.elapsed() >= arp_request_interval { let arp_request = create_arp_request(src_mac, src_ip, router_ip)?; tx.send_to(&arp_request, None) .unwrap() .expect("Failed to send ARP request for router address"); //println!("ARP request for router address sent") last_request_time = Instant::now(); } // Receive ARP reply from router match rx.next() { Ok(packet) => { let ethernet_packet = unwrap_or_continue!(EthernetPacket::new(packet)); if ethernet_packet.get_ethertype() == EtherTypes::Arp { let arp_packet = unwrap_or_continue!(ArpPacket::new(ethernet_packet.payload())); if arp_packet.get_operation() == ArpOperations::Request && arp_packet.get_target_proto_addr() == src_ip && arp_packet.get_sender_proto_addr() == router_ip { println!("Received implicit Router ARP request: {:?}", arp_packet); let arp_reply = create_arp_reply( src_mac, src_ip, arp_packet.get_sender_hw_addr(), arp_packet.get_sender_proto_addr(), ) .unwrap(); tx.send_to(&arp_reply, None) .unwrap() .expect("Failed to send ARP reply to router"); return Some(arp_packet.get_sender_hw_addr()); } if arp_packet.get_operation() == ArpOperations::Reply && arp_packet.get_sender_proto_addr() == router_ip { // println!("Received explicit ARP reply from router: {:?}", arp_packet); return Some(arp_packet.get_sender_hw_addr()); } } } Err(e) => panic!("Failed to receive router ARP reply: {}", e), } } } /// Create an ARP request packet. /// /// When the src_ip and dst_ip are the same, the ARP request is a gratuitous ARP request. /// This explicitly announces its own IP address to the network. /// /// # Arguments /// * `src_mac` - The source MAC address /// * `src_ip` - The source IP address /// * `dst_ip` - The destination IP address /// /// # Returns /// The ARP request packet /// /// # Example /// /// ```ignore /// use pnet::util::MacAddr; /// use std::net::Ipv4Addr; /// use crate::mrpf::network::arp::create_arp_request; /// /// /// let src_mac = MacAddr::new(0x00, 0x00, 0x00, 0x00, 0x00, 0x00); /// let src_ip = Ipv4Addr::new(192, 168, 1, 182); /// /// let arp_request = create_arp_request(src_mac, src_ip, Ipv4Addr::new(192, 168, 1, 1)); /// ``` #[inline] fn create_arp_request( src_mac: MacAddr, src_ip: Ipv4Addr, dst_ip: Ipv4Addr, ) -> Option<[u8; ARP_REQUEST_SIZE]> { let mut packet = [0u8; ARP_REQUEST_SIZE]; // Ethernet packets have a minimum size of 60 bytes, we need to zero the rest // TODO: can we use a more efficient method to zero the rest of the packet? 
// We would likely want to reverse the order of packet creation layers packet.fill(0); // Generate Ethernet packet let mut ethernet_packet = MutableEthernetPacket::new(&mut packet).unwrap(); ethernet_packet.set_destination(MacAddr::broadcast()); ethernet_packet.set_source(src_mac); ethernet_packet.set_ethertype(EtherTypes::Arp); // Generate ARP Request Packet let mut arp_packet = MutableArpPacket::new(ethernet_packet.payload_mut()).unwrap(); arp_packet.set_hardware_type(ArpHardwareTypes::Ethernet); arp_packet.set_protocol_type(EtherTypes::Ipv4); arp_packet.set_hw_addr_len(6); arp_packet.set_proto_addr_len(4); arp_packet.set_operation(ArpOperations::Request); arp_packet.set_sender_hw_addr(src_mac); arp_packet.set_sender_proto_addr(src_ip); arp_packet.set_target_hw_addr(MacAddr::broadcast()); arp_packet.set_target_proto_addr(dst_ip); Some(packet) } /// Create an ARP reply packet. /// /// # Arguments /// * `src_mac` - The source MAC address /// * `src_ip` - The source IP address /// /// * `dst_mac` - The destination MAC address /// * `dst_ip` - The destination IP address /// /// # Returns /// The ARP reply packet /// /// # Example /// /// ```ignore /// use pnet::util::MacAddr; /// use std::net::Ipv4Addr; /// use crate::mrpf::network::arp::create_arp_reply; /// /// let src_mac = MacAddr::new(0x00, 0x00, 0x00, 0x00, 0x00, 0x00); /// let src_ip = Ipv4Addr::new(192, 168, 1, 182); /// let dst_mac = MacAddr::new(0x00, 0x00, 0x00, 0x00, 0x00, 0x00); /// let dst_ip = Ipv4Addr::new(192, 168, 1, 1); /// /// let arp_reply = create_arp_reply(src_mac, src_ip, dst_mac, dst_ip); /// ``` #[inline] pub fn create_arp_reply( src_mac: MacAddr, src_ip: Ipv4Addr, dst_mac: MacAddr, dst_ip: Ipv4Addr, ) -> Option<[u8; ARP_REQUEST_SIZE]> { let mut packet = [0u8; ARP_REQUEST_SIZE]; // Generate Ethernet packet let mut ethernet_packet = MutableEthernetPacket::new(&mut packet).unwrap(); ethernet_packet.set_destination(dst_mac); ethernet_packet.set_source(src_mac); ethernet_packet.set_ethertype(EtherTypes::Arp); // Generate ARP Request Packet let mut arp_packet = MutableArpPacket::new(ethernet_packet.payload_mut()).unwrap(); arp_packet.set_hardware_type(ArpHardwareTypes::Ethernet); arp_packet.set_protocol_type(EtherTypes::Ipv4); arp_packet.set_hw_addr_len(6); arp_packet.set_proto_addr_len(4); arp_packet.set_operation(ArpOperations::Reply); arp_packet.set_sender_hw_addr(src_mac); arp_packet.set_sender_proto_addr(src_ip); arp_packet.set_target_hw_addr(dst_mac); arp_packet.set_target_proto_addr(dst_ip); Some(packet) } /// Check if received packet is an ARP request and send an ARP reply. 
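/// A minimal usage sketch (illustrative; the receive loop, `config`, and the
/// `packet_tx` channel are assumed to exist in the caller):
///
/// ```ignore
/// if handle_arp_requests(&ethernet_packet, &config, &packet_tx) {
///     // The frame was an ARP request addressed to us and a reply was queued;
///     // no further protocol parsing is needed for this frame.
///     continue;
/// }
/// ```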
/// /// # Arguments /// * `ethernet_packet` - The Ethernet packet /// * `config` - The scanner configuration /// * `tx` - The data link sender /// /// # Returns /// /// `true` if the packet was an ARP request and an ARP reply was sent, `false` otherwise #[inline] pub fn handle_arp_requests( ethernet_packet: &EthernetPacket, config: &ScannerConfig, packet_tx: &Sender>, ) -> bool { if ethernet_packet.get_ethertype() == EtherTypes::Arp { let result = ArpPacket::new(ethernet_packet.payload()); if let Some(arp_packet) = result && arp_packet.get_target_hw_addr() == config.src_mac { let arp_reply = create_arp_reply( config.src_mac, config.src_ip, arp_packet.get_sender_hw_addr(), arp_packet.get_sender_proto_addr(), ) .unwrap(); match packet_tx.send(arp_reply.to_vec()) { Ok(_) => {} Err(_) => { println!("Failed to send ARP reply packet to sender channel"); return false; } } return true; } } false } ``` ### `scanners/mrpf_engine/src/network/datalink.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/datalink.rs` ```rust use pnet::datalink::{self, Channel::Ethernet, DataLinkReceiver, DataLinkSender, NetworkInterface}; use std::thread::sleep; use std::time::Duration; // Send error codes const NO_BUFFER_SPACE_AVAILABLE_ERROR: Option = Some(55); /// Create a datalink channel for sending and receiving packets. /// /// Note that superuser permissions are required to open a datalink channel. /// /// # Arguments /// /// * `interface` - The network interface to open the channel on. /// * `read_buffer_size` - The size of the read buffer for holding packets. The source code mentions the buffer must be word aligned. /// * `write_buffer_size` - The size of the write buffer holding packets. The source code mentions the buffer must be word aligned. /// /// # Returns /// A tuple containing the sender and receiver for the datalink channel. pub fn create_datalink_channel( interface: &NetworkInterface, read_buffer_size: usize, write_buffer_size: usize, ) -> (Box, Box) { let config = pnet::datalink::Config { write_buffer_size, read_buffer_size, read_timeout: None, write_timeout: None, channel_type: datalink::ChannelType::Layer2, bpf_fd_attempts: 1000, linux_fanout: None, promiscuous: true, socket_fd: None, }; match datalink::channel(interface, config) { Ok(Ethernet(tx, rx)) => (tx, rx), Ok(_) => panic!("Unhandled channel type"), Err(e) => panic!( "Failed to create datalink channel:\n - Error: {}\n - Interface: {:?}", e, interface ), } } /// Send the given packet. /// /// Retry sending the packet if the write buffer is full. /// /// # Arguments /// * `tx` - The data link sender to use. /// * `packet` - The packet to send. /// * `full_write_buffer_count` - The number of times the write buffer was full. /// /// NOTE: this function should be inlined as it is called in a busy loop. 
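/// A hedged usage sketch (assumes a channel from `create_datalink_channel` above and a
/// prepared raw `packet`; the counter name is illustrative):
///
/// ```ignore
/// let (mut tx, _rx) = create_datalink_channel(&interface, 1 << 20, 1 << 20);
/// let mut full_write_buffer_count: u8 = 0;
/// send_packet(&mut tx, &packet, &mut full_write_buffer_count);
/// // `full_write_buffer_count` now reflects how often a full write buffer forced a retry.
/// ```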
#[inline] pub fn send_packet( tx: &mut Box, packet: &[u8], full_write_buffer_count: &mut u8, ) { loop { match tx.send_to(packet, None).unwrap() { Ok(_) => {} Err(e) => { if e.raw_os_error() == NO_BUFFER_SPACE_AVAILABLE_ERROR { *full_write_buffer_count += 1; sleep(Duration::from_millis(100)); continue; } else { panic!("Failed to send packet: {}", e) } } } break; } } ``` ### `scanners/mrpf_engine/src/network/dns.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/dns.rs` ```rust use super::ethernet::create_ethernet_packet; use super::ip::{create_ipv4_packet, generate_flow_cookie_v4}; use super::{PKT_ETH_SIZE, PKT_IPV4_SIZE, PKT_UDP_SIZE}; use crate::ScannerConfig; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::udp::MutableUdpPacket; use pnet::util::MacAddr; use serde::{Deserialize, Serialize}; use std::fmt; use std::net::{Ipv4Addr, Ipv6Addr}; use std::str; /// Common DNS query types we care about for recon. /// Only a subset included for now; extend as needed. #[repr(u16)] #[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum DnsQueryType { A = 1, NS = 2, CNAME = 5, SOA = 6, PTR = 12, MX = 15, TXT = 16, AAAA = 28, SRV = 33, SVCB = 64, // RFC 9460 (formerly type 64) HTTPS = 65, // RFC 9460 URI = 256, CAA = 257, } impl fmt::Display for DnsQueryType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { DnsQueryType::A => "A", DnsQueryType::NS => "NS", DnsQueryType::CNAME => "CNAME", DnsQueryType::SOA => "SOA", DnsQueryType::PTR => "PTR", DnsQueryType::MX => "MX", DnsQueryType::TXT => "TXT", DnsQueryType::AAAA => "AAAA", DnsQueryType::SRV => "SRV", DnsQueryType::SVCB => "SVCB", DnsQueryType::HTTPS => "HTTPS", DnsQueryType::URI => "URI", DnsQueryType::CAA => "CAA", }; f.write_str(s) } } /// Build a complete Ethernet+IPv4+UDP DNS query packet (traditional UDP, 512 byte max payload). /// /// Returns (transaction_id, packet). /// /// Assumptions: /// - UDP checksum is omitted (set to 0) for performance; most resolvers accept this for IPv4. /// - Caller chooses ephemeral source port. /// - dst_port typically 53. /// - Buffer reserves full 512 bytes after UDP header for DNS payload; unused bytes are left zeroed. 
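/// A usage sketch (illustrative; `config`, `router_mac`, `tx`, and the retry counter are
/// assumed to come from the scanner setup, and the resolver address is an example value):
///
/// ```ignore
/// let resolver = Ipv4Addr::new(8, 8, 8, 8);
/// if let Some((tx_id, packet)) = create_dns_query_packet_v4(
///     &config, &router_mac, 40_000, &resolver, 53, "example.com", DnsQueryType::A,
/// ) {
///     // `tx_id` is derived from the flow cookie and can be matched against replies.
///     send_packet(&mut tx, &packet, &mut full_write_buffer_count);
/// }
/// ```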
#[inline] pub fn create_dns_query_packet_v4( config: &ScannerConfig, router_mac: &MacAddr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, fqdn: &str, query_type: DnsQueryType, ) -> Option<(u16, Vec)> { const DNS_HEADER_STATIC: [u8; 10] = [ 0x01, 0x00, // flags 0x00, 0x01, // QDCOUNT = 1 0x00, 0x00, // ANCOUNT = 0 0x00, 0x00, // NSCOUNT = 0 0x00, 0x00, // ARCOUNT = 0 ]; let flow_cookie = generate_flow_cookie_v4(&config.src_ip, src_port, dst_ip, dst_port, config.entropy); let tx_id: u16 = (flow_cookie & 0xFFFF) as u16; let mut dns_payload: Vec = Vec::with_capacity(12 + fqdn.len() + 6); dns_payload.extend_from_slice(&tx_id.to_be_bytes()); dns_payload.extend_from_slice(&DNS_HEADER_STATIC); for label in fqdn.split('.') { let bytes = label.as_bytes(); if bytes.is_empty() || bytes.len() > 63 { return None; } dns_payload.push(bytes.len() as u8); dns_payload.extend_from_slice(bytes); } dns_payload.push(0); // terminator dns_payload.extend_from_slice(&(query_type as u16).to_be_bytes()); dns_payload.extend_from_slice(&1u16.to_be_bytes()); // QCLASS IN let dns_len = dns_payload.len(); let total_ipv4_len = (PKT_IPV4_SIZE + PKT_UDP_SIZE + dns_len) as u16; let total_packet_len = PKT_ETH_SIZE + total_ipv4_len as usize; let mut packet = vec![0u8; total_packet_len]; let ipv4_offset = PKT_ETH_SIZE; let udp_offset = PKT_ETH_SIZE + PKT_IPV4_SIZE; let dns_offset = udp_offset + PKT_UDP_SIZE; create_ethernet_packet(&config.src_mac, router_mac, &mut packet)?; create_ipv4_packet( &config.src_ip, dst_ip, total_ipv4_len, Some(IpNextHeaderProtocols::Udp), &mut packet[ipv4_offset..], )?; { let mut udp_packet = MutableUdpPacket::new(&mut packet[udp_offset..udp_offset + PKT_UDP_SIZE + dns_len])?; udp_packet.set_source(src_port); udp_packet.set_destination(dst_port); udp_packet.set_length((PKT_UDP_SIZE + dns_len) as u16); udp_packet.set_checksum(0); } packet[dns_offset..dns_offset + dns_len].copy_from_slice(&dns_payload); Some((tx_id, packet)) } #[derive(Debug, Clone)] pub enum DnsResourceRecord { A { owner_name: String, addr: Ipv4Addr, }, AAAA { owner_name: String, addr: Ipv6Addr, }, CNAME { owner_name: String, cname: String, }, MX { owner_name: String, preference: u16, exchange: String, }, SOA { owner_name: String, mname: String, rname: String, serial: u32, refresh: u32, retry: u32, expire: u32, minimum: u32, }, TXT { owner_name: String, texts: Vec, }, CAA { owner_name: String, flags: u8, tag: String, value: String, }, HTTPS { owner_name: String, priority: u16, target: String, params: Vec<(u16, Vec)>, }, PTR { owner_name: String, ptrdname: String, }, URI { owner_name: String, priority: u16, weight: u16, target: String, }, Unknown { owner_name: String, rr_type: u16, data: Vec, }, } impl DnsResourceRecord { /// Get the owner name of the resource record. pub fn owner_name(&self) -> &str { match self { DnsResourceRecord::A { owner_name, .. } => owner_name, DnsResourceRecord::AAAA { owner_name, .. } => owner_name, DnsResourceRecord::CNAME { owner_name, .. } => owner_name, DnsResourceRecord::MX { owner_name, .. } => owner_name, DnsResourceRecord::SOA { owner_name, .. } => owner_name, DnsResourceRecord::TXT { owner_name, .. } => owner_name, DnsResourceRecord::CAA { owner_name, .. } => owner_name, DnsResourceRecord::HTTPS { owner_name, .. } => owner_name, DnsResourceRecord::PTR { owner_name, .. } => owner_name, DnsResourceRecord::URI { owner_name, .. } => owner_name, DnsResourceRecord::Unknown { owner_name, .. 
} => owner_name, } } } #[inline(always)] fn read_u16(buf: &[u8], idx: usize) -> Option { if idx + 2 > buf.len() { None } else { Some(u16::from_be_bytes([buf[idx], buf[idx + 1]])) } } #[inline(always)] fn read_u32(buf: &[u8], idx: usize) -> Option { if idx + 4 > buf.len() { None } else { Some(u32::from_be_bytes([ buf[idx], buf[idx + 1], buf[idx + 2], buf[idx + 3], ])) } } /// Decode a domain name starting at `offset` using DNS label compression. /// Returns (name, bytes_consumed_for_uncompressed_part). fn decode_name(buf: &[u8], mut offset: usize) -> Option<(String, usize)> { let _start = offset; // kept for potential future reference when debugging name offsets if offset >= buf.len() { return None; } let mut labels: Vec = Vec::new(); // To avoid infinite loops on malicious packets track jumps. let mut jumps = 0; let mut consumed = 0; let mut jumped = false; loop { if offset >= buf.len() { return None; } let len = buf[offset]; if len & 0xC0 == 0xC0 { // compression pointer if offset + 1 >= buf.len() { return None; } let ptr = (((len as u16) & 0x3F) << 8) | buf[offset + 1] as u16; if ptr as usize >= buf.len() { return None; } if jumps > 10 { // arbitrary safety cap return None; } jumps += 1; if !jumped { consumed += 2; // only count initial pointer bytes towards consumed } offset = ptr as usize; jumped = true; continue; } else if len == 0 { // end if !jumped { consumed += 1; } break; } else { let label_len = len as usize; let label_start = offset + 1; let label_end = label_start + label_len; if label_end > buf.len() { return None; } let label_bytes = &buf[label_start..label_end]; if let Ok(label_str) = str::from_utf8(label_bytes) { labels.push(label_str.to_string()); } else { return None; // non-UTF8 label not expected in typical recon } if !jumped { consumed += 1 + label_len; } offset = label_end; } } let name = if labels.is_empty() { String::from(".") } else { labels.join(".") }; Some((name, consumed)) } fn parse_txt(r: &[u8]) -> Vec { let mut out = Vec::new(); let mut i = 0; while i < r.len() { let l = r[i] as usize; i += 1; if i + l > r.len() { break; } if let Ok(s) = str::from_utf8(&r[i..i + l]) { out.push(s.to_string()); } i += l; } out } fn parse_caa(owner: &str, r: &[u8]) -> Option { if r.len() < 2 { return None; } let flags = r[0]; let tag_len = r[1] as usize; if 2 + tag_len > r.len() { return None; } let tag = str::from_utf8(&r[2..2 + tag_len]).ok()?.to_string(); let value = str::from_utf8(&r[2 + tag_len..]).unwrap_or("").to_string(); Some(DnsResourceRecord::CAA { owner_name: owner.to_string(), flags, tag, value, }) } fn parse_https(owner: &str, buf: &[u8], msg: &[u8], r_off: usize) -> Option { if buf.len() < 2 { return None; } let priority = u16::from_be_bytes([buf[0], buf[1]]); let mut idx = 2; let name_off = r_off + idx; let (target, cons) = decode_name(msg, name_off)?; idx += cons; if idx > buf.len() { return None; } let mut params = Vec::new(); while idx + 4 <= buf.len() { let key = u16::from_be_bytes([buf[idx], buf[idx + 1]]); let len = u16::from_be_bytes([buf[idx + 2], buf[idx + 3]]) as usize; idx += 4; if idx + len > buf.len() { break; } params.push((key, buf[idx..idx + len].to_vec())); idx += len; } Some(DnsResourceRecord::HTTPS { owner_name: owner.to_string(), priority, target, params, }) } fn parse_mx(owner: &str, msg: &[u8], r: &[u8], r_off: usize) -> Option { if r.len() < 2 { return None; } let pref = u16::from_be_bytes([r[0], r[1]]); let (ex, _) = decode_name(msg, r_off + 2)?; Some(DnsResourceRecord::MX { owner_name: owner.to_string(), preference: pref, exchange: ex, 
}) } fn parse_soa(owner: &str, msg: &[u8], r_off: usize) -> Option { let (mname, mc) = decode_name(msg, r_off)?; let (rname, rc) = decode_name(msg, r_off + mc)?; let base = r_off + mc + rc; let serial = read_u32(msg, base)?; let refresh = read_u32(msg, base + 4)?; let retry = read_u32(msg, base + 8)?; let expire = read_u32(msg, base + 12)?; let minimum = read_u32(msg, base + 16)?; Some(DnsResourceRecord::SOA { owner_name: owner.to_string(), mname, rname, serial, refresh, retry, expire, minimum, }) } fn parse_uri(owner: &str, r: &[u8]) -> Option { if r.len() < 4 { return None; } let priority = u16::from_be_bytes([r[0], r[1]]); let weight = u16::from_be_bytes([r[2], r[3]]); let target = str::from_utf8(&r[4..]).ok()?.to_string(); Some(DnsResourceRecord::URI { owner_name: owner.to_string(), priority, weight, target, }) } /// Parse a DNS response payload (UDP payload beginning at DNS header) into resource records. /// Returns all answer + authority + additional records as a flat Vec. /// Invalid packets yield an empty Vec. pub fn parse_dns_response(payload: &[u8]) -> Option> { const HDR: usize = 12; if payload.len() < HDR { return None; } let flags = read_u16(payload, 2)?; if (flags & 0x8000) == 0 { return None; } let qd = read_u16(payload, 4).unwrap_or(0) as usize; let an = read_u16(payload, 6).unwrap_or(0) as usize; let ns = read_u16(payload, 8).unwrap_or(0) as usize; let ar = read_u16(payload, 10).unwrap_or(0) as usize; let mut offset = HDR; for _ in 0..qd { let (_n, cons) = decode_name(payload, offset)?; offset += cons + 4; if offset > payload.len() { return None; } } let total = an + ns + ar; let mut out = Vec::with_capacity(total); for _ in 0..total { if offset >= payload.len() { break; } let (owner, cons) = match decode_name(payload, offset) { Some(v) => v, None => break, }; offset += cons; if offset + 10 > payload.len() { break; } let rr_type = read_u16(payload, offset).unwrap(); offset += 8; let rdlen = read_u16(payload, offset).unwrap_or(0) as usize; offset += 2; if offset + rdlen > payload.len() { break; } let r_start = offset; let r_slice = &payload[offset..offset + rdlen]; let rec = match rr_type { 1 => { if rdlen == 4 { DnsResourceRecord::A { owner_name: owner.clone(), addr: Ipv4Addr::new(r_slice[0], r_slice[1], r_slice[2], r_slice[3]), } } else { DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), } } } 28 => { if rdlen == 16 { let mut oct = [0u8; 16]; oct.copy_from_slice(r_slice); DnsResourceRecord::AAAA { owner_name: owner.clone(), addr: Ipv6Addr::from(oct), } } else { DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), } } } 5 => match decode_name(payload, r_start) { Some((c, _)) => DnsResourceRecord::CNAME { owner_name: owner.clone(), cname: c, }, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, 15 => match parse_mx(&owner, payload, r_slice, r_start) { Some(mx) => mx, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, 6 => match parse_soa(&owner, payload, r_start) { Some(soa) => soa, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, 16 => DnsResourceRecord::TXT { owner_name: owner.clone(), texts: parse_txt(r_slice), }, 257 => match parse_caa(&owner, r_slice) { Some(caa) => caa, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, 65 => match parse_https(&owner, r_slice, payload, r_start) { Some(h) 
=> h, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, 12 => match decode_name(payload, r_start) { Some((p, _)) => DnsResourceRecord::PTR { owner_name: owner.clone(), ptrdname: p, }, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, 256 => match parse_uri(&owner, r_slice) { Some(uri) => uri, None => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }, _ => DnsResourceRecord::Unknown { owner_name: owner.clone(), rr_type, data: r_slice.to_vec(), }, }; offset += rdlen; out.push(rec); } Some(out) } /// Transform an input domain or IP string into the correct reverse lookup (PTR query) name. /// /// Supports: /// - IPv4 (e.g. 1.2.3.4 -> 4.3.2.1.in-addr.arpa.) /// - IPv6 (full/short) by expanding, nibble reversing -> .ip6.arpa. /// - Already reversed names (ends with .in-addr.arpa. or .ip6.arpa.) returned as-is (normalized to trailing dot). /// /// Returns None for invalid input. pub fn domain_name_to_ptr(input: &str) -> Option { let trimmed = input.trim_end_matches('.'); if trimmed.is_empty() { return None; } // If already looks like a reverse domain, normalize by ensuring trailing dot. if trimmed.ends_with("in-addr.arpa") || trimmed.ends_with("ip6.arpa") { return Some(format!("{}.", trimmed)); } // Try IPv4 first (fast path, most common) if let Ok(v4) = trimmed.parse::() { let octets = v4.octets(); return Some(format!( "{}.{}.{}.{}.in-addr.arpa.", octets[3], octets[2], octets[1], octets[0] )); } // Try IPv6: need to expand to full 32 hex nibbles then reverse nibble order. if let Ok(v6) = trimmed.parse::() { let segments = v6.segments(); // 8 u16 segments // Convert segments to 4 hex chars each (zero padded) -> 32 chars let mut full_hex = String::with_capacity(32); for seg in &segments { full_hex.push_str(&format!("{:04x}", seg)); } // Reverse nibble order with dots between each nibble let mut out = String::with_capacity((32 * 2) + 10); // 32 nibbles + dots + suffix for ch in full_hex.chars().rev() { out.push(ch); out.push('.'); } out.push_str("ip6.arpa."); return Some(out); } // Fallback: treat as plain hostname - cannot convert to PTR without IP None } #[cfg(test)] mod tests { use super::domain_name_to_ptr; #[test] fn ipv4_to_ptr() { assert_eq!( domain_name_to_ptr("1.2.3.4"), Some("4.3.2.1.in-addr.arpa.".to_string()) ); } #[test] fn ipv6_to_ptr() { // 2001:db8::1 -> expanded 2001:0db8:0000:0000:0000:0000:0000:0001 let expected_prefix = "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."; assert_eq!( domain_name_to_ptr("2001:db8::1"), Some(expected_prefix.to_string()) ); } #[test] fn already_reversed_ipv4() { assert_eq!( domain_name_to_ptr("4.3.2.1.in-addr.arpa"), Some("4.3.2.1.in-addr.arpa.".to_string()) ); } #[test] fn already_reversed_ipv6() { let name = "1.0.0.0.ip6.arpa"; // truncated example assert_eq!( domain_name_to_ptr(name), Some("1.0.0.0.ip6.arpa.".to_string()) ); } #[test] fn invalid_hostname() { assert_eq!(domain_name_to_ptr("example.com"), None); // not an IP } #[test] fn empty_input() { assert_eq!(domain_name_to_ptr(""), None); } } ``` ### `scanners/mrpf_engine/src/network/ethernet.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/ethernet.rs` ```rust use pnet::packet::ethernet::{EtherTypes, MutableEthernetPacket}; use pnet::util::MacAddr; /// Create an Ethernet packet #[inline] pub(crate) fn create_ethernet_packet<'a>( src_mac: &MacAddr, router_mac: &MacAddr, packet: &'a mut [u8], ) -> 
Option> { let mut ethernet_packet = MutableEthernetPacket::new(packet)?; ethernet_packet.set_ethertype(EtherTypes::Ipv4); ethernet_packet.set_source(*src_mac); ethernet_packet.set_destination(*router_mac); Some(ethernet_packet) } ``` ### `scanners/mrpf_engine/src/network/http.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/http.rs` ```rust ``` ### `scanners/mrpf_engine/src/network/ip.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/ip.rs` ```rust use pnet::packet::ip::IpNextHeaderProtocol; use pnet::packet::ipv4::MutableIpv4Packet; #[cfg(feature = "siphasher")] use siphasher::sip::SipHasher24; #[cfg(feature = "siphasher")] use std::hash::Hasher; use std::net::Ipv4Addr; /// Generate a stable per-flow cookie (32-bit) for correlating packets without state. /// /// Based on original TCP SYN cookie logic; generalized so other protocols (e.g. DNS transaction IDs) /// can derive IDs by truncating or mixing this value. #[inline] pub fn generate_flow_cookie_v4( src_ip: &Ipv4Addr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, entropy: u64, ) -> u32 { #[cfg(feature = "siphasher")] let cookie = { let mut hasher = SipHasher24::new_with_keys(entropy, entropy); hasher.write_u32(src_ip.to_bits()); hasher.write_u16(src_port); hasher.write_u32(dst_ip.to_bits()); hasher.write_u16(dst_port); hasher.finish() as u32 }; #[cfg(not(feature = "siphasher"))] let cookie = { let mut data = [0u8; 12]; data[..4].copy_from_slice(&src_ip.octets()); data[4..6].copy_from_slice(&src_port.to_be_bytes()); data[6..10].copy_from_slice(&dst_ip.octets()); data[10..12].copy_from_slice(&dst_port.to_be_bytes()); gxhash::gxhash32(&data, entropy as i64) }; cookie } #[inline] pub(crate) fn create_ipv4_packet<'a>( src_ip: &Ipv4Addr, dst_ip: &Ipv4Addr, size: u16, next_protocol: Option, packet: &'a mut [u8], ) -> Option> { let mut ipv4_packet = MutableIpv4Packet::new(packet)?; ipv4_packet.set_version(4); ipv4_packet.set_header_length(5); ipv4_packet.set_total_length(size); ipv4_packet.set_ttl(64); ipv4_packet.set_source(*src_ip); ipv4_packet.set_destination(*dst_ip); ipv4_packet.set_flags(0x2); // Don't fragment if let Some(next_protocol) = next_protocol { ipv4_packet.set_next_level_protocol(next_protocol); } // Calculate the checksum ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable())); Some(ipv4_packet) } ``` ### `scanners/mrpf_engine/src/network/mod.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/mod.rs` ```rust pub mod arp; pub mod datalink; pub mod dns; pub mod ethernet; pub mod ip; pub mod tcp; pub mod tls; use pnet::packet::ethernet::EthernetPacket; use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::tcp::TcpPacket; use pnet::packet::udp::UdpPacket; // Minimal packet sizes const PKT_ETH_SIZE: usize = EthernetPacket::minimum_packet_size(); const PKT_IPV4_SIZE: usize = Ipv4Packet::minimum_packet_size(); const PKT_TCP_SIZE: usize = TcpPacket::minimum_packet_size(); const PKT_UDP_SIZE: usize = UdpPacket::minimum_packet_size(); // Packet offsets for each layer when using datalink layer const PKT_IPV4_OFFSET: usize = PKT_ETH_SIZE; const PKT_TCP_OFFSET: usize = PKT_ETH_SIZE + PKT_IPV4_SIZE; ``` ### `scanners/mrpf_engine/src/network/tcp.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/tcp.rs` ```rust use super::ethernet::create_ethernet_packet; use super::ip::{create_ipv4_packet, generate_flow_cookie_v4}; use super::{PKT_ETH_SIZE, PKT_IPV4_OFFSET, PKT_IPV4_SIZE, PKT_TCP_OFFSET, PKT_TCP_SIZE}; use crate::ScannerConfig; use 
pnet::packet::Packet; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::tcp::MutableTcpPacket; use pnet::packet::tcp::TcpFlags::{ACK, RST, SYN}; use pnet::packet::tcp::TcpOption; use pnet::util::MacAddr; use std::net::Ipv4Addr; const PKT_TCP_OPTION_MSS_SIZE: usize = 4; #[inline] pub fn create_syn_tcp_packet_v4( config: &ScannerConfig, router_mac: &MacAddr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, ) -> Option<( u32, // syn cookie [u8; PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE + PKT_TCP_OPTION_MSS_SIZE], )> { let mut packet = [0u8; PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE + PKT_TCP_OPTION_MSS_SIZE]; // Optionally, reset the packet to zeros at any point packet.fill(0); // Create an Ethernet packet create_ethernet_packet(&config.src_mac, router_mac, &mut packet)?; // Create an IPv4 packet create_ipv4_packet( &config.src_ip, dst_ip, (PKT_IPV4_SIZE + PKT_TCP_SIZE + PKT_TCP_OPTION_MSS_SIZE) as u16, Some(IpNextHeaderProtocols::Tcp), &mut packet[PKT_IPV4_OFFSET..], )?; // Generate a SYN cookie. // // This is a 32-bit value that is used to generate the TCP sequence number. // It is used to identify the connection on the receive thread and // generated by hashing the source and destination ip addresses and ports. let syn_cookie = generate_flow_cookie_v4(&config.src_ip, src_port, dst_ip, dst_port, config.entropy); // Create a TCP packet { let mut tcp_packet = MutableTcpPacket::new(&mut packet[PKT_TCP_OFFSET..])?; tcp_packet.set_source(src_port); tcp_packet.set_destination(dst_port); tcp_packet.set_sequence(syn_cookie); tcp_packet.set_acknowledgement(0); tcp_packet.set_window(65535); tcp_packet.set_flags(SYN); // 5 (20 bytes header) + 1 (4 bytes options) adds up to 24 bytes // which is a multitude of 32 bits as per the TCP RFC. tcp_packet.set_data_offset(6); tcp_packet.set_options(&[TcpOption::mss(1460)]); let checksum = pnet::util::ipv4_checksum( tcp_packet.packet(), 8, &[], &config.src_ip, dst_ip, IpNextHeaderProtocols::Tcp, ); tcp_packet.set_checksum(checksum); } Some((syn_cookie, packet)) } #[inline] pub fn create_tcp_rst( src_mac: &MacAddr, router_mac: &MacAddr, src_ip: &Ipv4Addr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, sequence: u32, ) -> Option<[u8; PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE]> { let mut packet = [0u8; PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE]; // Optionally, reset the packet to zeros at any point packet.fill(0); // Create an Ethernet packet create_ethernet_packet(src_mac, router_mac, &mut packet)?; // Create an IPv4 packet create_ipv4_packet( src_ip, dst_ip, (PKT_IPV4_SIZE + PKT_TCP_SIZE) as u16, Some(IpNextHeaderProtocols::Tcp), &mut packet[PKT_IPV4_OFFSET..], )?; // Create TCP RST packet { let mut tcp_packet = MutableTcpPacket::new(&mut packet[PKT_TCP_OFFSET..])?; tcp_packet.set_source(src_port); tcp_packet.set_destination(dst_port); tcp_packet.set_sequence(sequence); // tcp_packet.set_acknowledgement(acknowledgement); // RST packets should not set an ACK tcp_packet.set_window(65535); tcp_packet.set_flags(RST); tcp_packet.set_data_offset(5); let checksum = pnet::util::ipv4_checksum( tcp_packet.packet(), 8, &[], src_ip, dst_ip, IpNextHeaderProtocols::Tcp, ); tcp_packet.set_checksum(checksum); } Some(packet) } #[inline] #[expect(clippy::too_many_arguments)] pub fn create_tcp_ack( src_mac: &MacAddr, router_mac: &MacAddr, src_ip: &Ipv4Addr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, sequence: u32, acknowledgement: u32, ) -> Option<[u8; PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE]> { let mut packet = [0u8; PKT_ETH_SIZE + PKT_IPV4_SIZE + 
PKT_TCP_SIZE]; // Optionally, reset the packet to zeros at any point packet.fill(0); // Create an Ethernet packet create_ethernet_packet(src_mac, router_mac, &mut packet)?; // Create an IPv4 packet create_ipv4_packet( src_ip, dst_ip, (PKT_IPV4_SIZE + PKT_TCP_SIZE) as u16, Some(IpNextHeaderProtocols::Tcp), &mut packet[PKT_IPV4_OFFSET..], )?; // Create TCP ACK packet { let mut tcp_packet = MutableTcpPacket::new(&mut packet[PKT_TCP_OFFSET..])?; tcp_packet.set_source(src_port); tcp_packet.set_destination(dst_port); tcp_packet.set_sequence(sequence); tcp_packet.set_acknowledgement(acknowledgement); tcp_packet.set_window(65535); tcp_packet.set_flags(ACK); tcp_packet.set_data_offset(5); let checksum = pnet::util::ipv4_checksum( tcp_packet.packet(), 8, &[], src_ip, dst_ip, IpNextHeaderProtocols::Tcp, ); tcp_packet.set_checksum(checksum); } Some(packet) } #[inline] #[expect(clippy::too_many_arguments)] pub fn create_tcp_ack_with_data( src_mac: &MacAddr, router_mac: &MacAddr, src_ip: &Ipv4Addr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, sequence: u32, acknowledgement: u32, payload: &[u8], ) -> Option> { let packet_length = PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE + payload.len(); let mut packet = vec![0u8; packet_length]; // Optionally, reset the packet to zeros at any point packet.fill(0); // Create an Ethernet packet create_ethernet_packet(src_mac, router_mac, &mut packet)?; // Create an IPv4 packet create_ipv4_packet( src_ip, dst_ip, (PKT_IPV4_SIZE + PKT_TCP_SIZE + payload.len()) as u16, Some(IpNextHeaderProtocols::Tcp), &mut packet[PKT_IPV4_OFFSET..], )?; // Create TCP ACK packet { let mut tcp_packet = MutableTcpPacket::new(&mut packet[PKT_TCP_OFFSET..])?; tcp_packet.set_source(src_port); tcp_packet.set_destination(dst_port); tcp_packet.set_sequence(sequence); tcp_packet.set_acknowledgement(acknowledgement); tcp_packet.set_window(65535); tcp_packet.set_flags(ACK); tcp_packet.set_data_offset(5); tcp_packet.set_payload(payload); let checksum = pnet::util::ipv4_checksum( tcp_packet.packet(), 8, &[], src_ip, dst_ip, IpNextHeaderProtocols::Tcp, ); tcp_packet.set_checksum(checksum); } Some(packet) } ``` ### `scanners/mrpf_engine/src/network/tls.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/network/tls.rs` ```rust use super::ethernet::create_ethernet_packet; use super::ip::create_ipv4_packet; use super::{PKT_ETH_SIZE, PKT_IPV4_OFFSET, PKT_IPV4_SIZE, PKT_TCP_OFFSET, PKT_TCP_SIZE}; use bytes::BytesMut; use mrpf_cert_parser::{parse_tlv, tags::TAG_SEQUENCE}; use pnet::packet::Packet; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::tcp::TcpFlags::ACK; use pnet::util::MacAddr; use std::fmt; use std::net::Ipv4Addr; // This const is a static version of a TLS 1.2 Client Hello message. It // does not include the handshake header and record header as these need // to be computed after we add the optional SNI extension. 
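//
// For orientation, the final on-the-wire layout assembled by `build_client_hello` below
// is roughly (lengths are computed at build time, so the values here are illustrative):
//
//   [0x16, 0x03, 0x01, len_hi, len_lo]   record header: handshake, record version TLS 1.0
//   [0x01, len, len, len]                handshake header: ClientHello, 3-byte length
//   TLS_12_CLIENT_HELLO                  version, random, session id, cipher suites, compression
//   [ext_len_hi, ext_len_lo]             total extensions length
//   default extensions (+ optional SNI)  as appended by `create_sni_extension`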
const TLS_12_CLIENT_HELLO: [u8; 71] = [ 0x03, 0x03, // TLS version 1.2 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, 0x13, 0x37, // Random bytes 0x00, // Session ID length (0) 0x00, 0x20, // Length of the cipher suites (32 bytes) 0x5a, 0x5a, // Reserved grease 0x13, 0x01, // TLS_AES_128_GCM_SHA256 0x13, 0x02, // TLS_AES_256_GCM_SHA384 0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256 0xc0, 0x2b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 0xc0, 0x2f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 0xc0, 0x2c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 0xc0, 0x30, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 0xcc, 0xa9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 0xcc, 0xa8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 0xc0, 0x13, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA 0xc0, 0x14, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA 0x00, 0x9c, // TLS_RSA_WITH_AES_128_GCM_SHA256 0x00, 0x9d, // TLS_RSA_WITH_AES_256_GCM_SHA384 0x00, 0x2f, // TLS_RSA_WITH_AES_128_CBC_SHA 0x00, 0x35, // TLS_RSA_WITH_AES_256_CBC_SHA 0x01, 0x00, // compression methods etc ]; #[inline] #[expect(clippy::too_many_arguments)] pub fn create_tls_client_hello( src_mac: &MacAddr, router_mac: &MacAddr, src_ip: &Ipv4Addr, src_port: u16, dst_ip: &Ipv4Addr, dst_port: u16, sequence: u32, acknowledgement: u32, sni: String, ) -> Option<(Vec, u32)> { use pnet::packet::tcp::MutableTcpPacket; let payload = build_client_hello(&[ServerName::new(sni.as_str())]); let packet_length = PKT_ETH_SIZE + PKT_IPV4_SIZE + PKT_TCP_SIZE + payload.len(); let mut packet = vec![0u8; packet_length]; // Create an Ethernet packet create_ethernet_packet(src_mac, router_mac, &mut packet)?; // Create an IPv4 packet create_ipv4_packet( src_ip, dst_ip, (PKT_IPV4_SIZE + PKT_TCP_SIZE + payload.len()) as u16, Some(IpNextHeaderProtocols::Tcp), &mut packet[PKT_IPV4_OFFSET..], )?; // Create a TCP packet { let mut tcp_packet = MutableTcpPacket::new(&mut packet[PKT_TCP_OFFSET..])?; tcp_packet.set_source(src_port); tcp_packet.set_destination(dst_port); tcp_packet.set_sequence(sequence); tcp_packet.set_acknowledgement(acknowledgement); tcp_packet.set_data_offset(5); tcp_packet.set_window(65535); tcp_packet.set_flags(ACK); tcp_packet.set_payload(&payload); let checksum = pnet::util::ipv4_checksum( tcp_packet.packet(), 8, &[], src_ip, dst_ip, IpNextHeaderProtocols::Tcp, ); tcp_packet.set_checksum(checksum); } Some((packet, payload.len() as u32)) } /// Representation of a Server Name in SNI extension #[derive(Debug)] pub struct ServerName { name: String, kind: ServerNameKind, } impl ServerName { /// Create a new ServerName using the only known type of HostName pub fn new(name: &str) -> ServerName { ServerName { name: name.to_string(), kind: ServerNameKind::HostName, } } /// Get the length of the ServerName pub fn len(&self) -> u16 { self.name.len() as u16 } /// Check if the ServerName is empty pub fn is_empty(&self) -> bool { self.name.is_empty() } /// Get the bytes of the ServerName within a SNI extension. This includes /// the lenght of the ServerName and the ServerName itself. 
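    /// A small illustration (hedged; the byte values follow directly from the layout above):
    ///
    /// ```ignore
    /// let name = ServerName::new("www.example.com");
    /// let bytes = name.as_bytes();
    /// // [0x00]             ServerNameKind::HostName
    /// // [0x00, 0x0f]       length 15, big endian
    /// // b"www.example.com" the raw host name
    /// assert_eq!(bytes.len(), 1 + 2 + 15);
    /// ```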
pub fn as_bytes(&self) -> Vec { let mut bytes = Vec::new(); bytes.push(self.kind as u8); // println!("ServerName kind {:?} bytes: {:02X?}", self.name, bytes); bytes.extend_from_slice(&self.len().to_be_bytes()); // println!( // "ServerName len {:?} bytes: {:02X?}", // &self.name.len().to_be_bytes(), // bytes // ); bytes.extend_from_slice(self.name.as_bytes()); // println!("ServerName {:?} bytes: {:02X?}", self.name, bytes); bytes } } /// SNI Server Name Types /// /// Only DNS hostnames are supported at the moment, but we're going to try and /// fuzz for other types. Lets use the types defined in the SAN extensions to /// see if there are any parsers dynamically getting the types here. /// /// It's also worth trying to play with the ASCII characters as described in the /// [RFC 4366](https://www.rfc-editor.org/rfc/rfc4366#section-3.1) /// /// ```rfc /// Currently, the only server names supported are DNS hostnames; /// however, this does not imply any dependency of TLS on DNS, and other /// name types may be added in the future (by an RFC that updates this /// document). TLS MAY treat provided server names as opaque data and /// pass the names and types to the application. /// /// "HostName" contains the fully qualified DNS hostname of the server, /// as understood by the client. The hostname is represented as a byte /// string using UTF-8 encoding [UTF8], without a trailing dot. /// /// If the hostname labels contain only US-ASCII characters, then the /// client MUST ensure that labels are separated only by the byte 0x2E, /// representing the dot character U+002E (requirement 1 in Section 3.1 /// of [IDNA] notwithstanding). If the server needs to match the /// HostName against names that contain non-US-ASCII characters, it MUST /// perform the conversion operation described in Section 4 of [IDNA], /// treating the HostName as a "query string" (i.e., the AllowUnassigned /// flag MUST be set). Note that IDNA allows labels to be separated by /// any of the Unicode characters U+002E, U+3002, U+FF0E, and U+FF61; /// therefore, servers MUST accept any of these characters as a label /// separator. If the server only needs to match the HostName against /// names containing exclusively ASCII characters, it MUST compare ASCII /// names case-insensitively. /// ``` /// /// The [rfc 6066](https://www.rfc-editor.org/rfc/rfc6066#page-6) is a follow up /// and it states that there can only be a single entry of a specific type. /// /// ```rfc /// The ServerNameList MUST NOT contain more than one name of the same name_type. /// /// Note: Earlier versions of this specification permitted multiple /// names of the same name_type. In practice, current client /// implementations only send one name, and the client cannot /// necessarily find out which name the server selected. Multiple /// names of the same name_type are therefore now prohibited. 
/// ``` #[derive(Debug, Copy, Clone)] pub enum ServerNameKind { /// The fully qualified DNS hostname of the server HostName = 0x00, /// An invalid ServerName Type Invalid = 0xff, } /// Compose a SNI extension for a domain name /// /// 0x00 0x00 indicates its the server_name extension /// 0x00 0x14 is the length of the server name extension /// 0x00 0x12 is the length of the server name list /// 0x00 is the server name type (host_name) /// 0x00 0x0f is the length of the server name /// /// The [SNI extension](https://datatracker.ietf.org/doc/html/rfc6066#page-6) states the following: /// /// The ServerNameList MUST NOT contain more than one name of the same name_type /// /// I can be interesting to see how servers handle this. /// /// 0x77,0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d is the servername in this case 'www.example.com' /// let example_sni: &[u8] = &[ /// 0x00, 0x00, 0x00, 0x14, 0x00, 0x12, 0x00, 0x00, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, /// 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, /// ]; pub fn create_sni_extension(server_names: &[ServerName]) -> Vec { let mut sni_extension = Vec::new(); for server_name in server_names { sni_extension.append(&mut server_name.as_bytes()); } let mut extension_header = BytesMut::new(); extension_header.extend_from_slice(&[0x00, 0x00]); // Extension type for SNI let extension_length = sni_extension.len() as u16 + 2; // Includes server name list length field extension_header.extend_from_slice(&extension_length.to_be_bytes()); // Extension length let server_name_list_length = sni_extension.len() as u16; extension_header.extend_from_slice(&server_name_list_length.to_be_bytes()); // Server name list length sni_extension.splice(0..0, extension_header); // println!("SNI Extension: {:02X?}", sni_extension); sni_extension } /// Builds a TLS Client Hello message with a SNI extension fn build_client_hello(server_names: &[ServerName]) -> Vec { // We need to compute the SNI extension, append it to the body, and add the total length of // extensions to the body. // We compose the Client Hello record here // We try to hardcode as much as possible to avoid // unnessesary computation // Layout of the Client Hello Data minus the headers and extensions // 2 bytes Client Version // 32 bytes Random // 1 byte Session ID Length // 2 bytes of cypher suites length // x bytes of cypher suites // 1 byte of compression methods length // x bytes of compression methods let client_hello_data = TLS_12_CLIENT_HELLO.as_ref(); // Some of the default extensions we'll need. 
Captured from request to www.example.com let default_extensions: &[u8] = &[ 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x2a, 0x00, 0x28, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x03, 0x03, 0x03, 0x01, 0x03, 0x02, 0x04, 0x02, 0x05, 0x02, 0x06, 0x02, ]; let body: &[u8] = if !server_names.is_empty() { // Generate our SNI extension based on the given server names let sni_extension = create_sni_extension(server_names); // Calculate the total lenght of our extensions let length_of_extensions: [u8; 2] = (default_extensions.len() as u16 + sni_extension.len() as u16).to_be_bytes(); &[ client_hello_data, &length_of_extensions, default_extensions, &sni_extension, ] .concat() } else { let length_of_extensions: [u8; 2] = (default_extensions.len() as u16).to_be_bytes(); &[client_hello_data, &length_of_extensions, default_extensions].concat() }; // Layout of the Handshake Header // 01 - handshake message type 0x01 (client hello) // 3 bytes of length 0xXX 0xXX 0xXX - (x bytes of client hello follows) let length_of_client_hello: [u8; 4] = (body.len() as u32).to_be_bytes(); let handshake_header: &[u8] = &[ 0x01, length_of_client_hello[1], length_of_client_hello[2], length_of_client_hello[3], ]; // Layout of the Record Header // 16 - type is 0x16 (handshake record) // 03 01 - protocol version is 3.1 (also known as TLS 1.0) Some TLS servers fail if the record version is greater than TLS1.0 for initial ClientHello // 2 bytes of length 0xXX 0xXX (xx bytes of handshake message follows) let body: &[u8] = &[handshake_header, body].concat(); let length_of_record: [u8; 2] = (body.len() as u16).to_be_bytes(); let record_header: &[u8] = &[0x16, 0x03, 0x01, length_of_record[0], length_of_record[1]]; [record_header, body].concat() } /// Error type for parsing TLS 1.2 records /// /// This error type is used when parsing TLS 1.2 records and the record is either incomplete or invalid. #[derive(Debug)] pub enum TlsExtractError { Incomplete, Invalid, } impl fmt::Display for TlsExtractError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TlsExtractError::Incomplete => write!( f, "Body is incomplete and does not contain a full TLS 1.2 record." ), TlsExtractError::Invalid => { write!(f, "Did not receive a valid ServerHello TLS 1.2 response.") } } } } impl std::error::Error for TlsExtractError {} /// Parse the TLS 1.2 Server Hello response and extract the certificate chain /// /// # Arguments /// * `body` - The body of the response containing the Server Hello message /// /// # Returns /// * `Result>` - The extracted certificate bytes /// Shared inline helper to advance a TLS 1.2 ServerHello + (possibly partial) Certificate record /// to the start of the Certificate handshake payload (certificate handshake header). /// /// It returns a tuple: (certificate_record_payload_slice, expected_certificate_record_payload_len). /// The slice may be shorter than the expected payload length if the full record has not yet been /// received. This enables callers that only need the leaf certificate to proceed as soon as its /// bytes are available, without waiting for the entire chain. /// /// We only enforce completeness for the first (ServerHello) record because we cannot position /// ourselves at the second record safely without its full length. 
For the certificate record we /// allow partial data and defer completeness decisions to higher-level functions. #[inline] fn advance_to_tls12_certificate_handshake( mut data: &[u8], ) -> Result<(&[u8], usize), TlsExtractError> { // Need at least first record header if data.len() < 5 { return Err(TlsExtractError::Incomplete); } if data[0] != 0x16 || data[1] != 0x03 || data[2] != 0x03 { return Err(TlsExtractError::Invalid); // Not a TLS1.2 Handshake record } let server_hello_record_len = ((data[3] as usize) << 8) | (data[4] as usize); let after_record = 5 + server_hello_record_len; if data.len() < after_record { // Need full ServerHello record to skip it return Err(TlsExtractError::Incomplete); } data = &data[after_record..]; // Second record header (Certificate). Allow partial payload. if data.len() < 5 { // Not enough for second record header yet return Err(TlsExtractError::Incomplete); } if data[0] != 0x16 || data[1] != 0x03 || data[2] != 0x03 { // Wrong content type/version return Err(TlsExtractError::Invalid); } let cert_record_len = ((data[3] as usize) << 8) | (data[4] as usize); let available_payload = data.len().saturating_sub(5); let payload_slice = &data[5..5 + available_payload.min(cert_record_len)]; // We do NOT error if available_payload < cert_record_len here; caller decides. Ok((payload_slice, cert_record_len)) } /// Extract the full TLS 1.2 certificate chain bytes (certificate_list) from a ServerHello response. /// Returns the raw certificate_list (3-byte length prefix + concatenated certificates) as a Vec. /// Drops all preceding handshake/record data immediately to minimize memory usage. pub fn extract_tls12_certificate_chain(body: &[u8]) -> Result, TlsExtractError> { let (payload, cert_chain_length) = advance_to_tls12_certificate_handshake(body)?; // Need at least handshake header (4 bytes) + cert_list length (3 bytes) if payload.len() < 7 { return Err(TlsExtractError::Incomplete); } if payload[0] != 0x0b { return Err(TlsExtractError::Invalid); } // Certificate handshake header: type (1) + length (3) let handshake_cert_list = &payload[4..]; // Need certificate_list length (3 bytes) if handshake_cert_list.len() < 3 { return Err(TlsExtractError::Incomplete); } let cert_list_len = ((handshake_cert_list[0] as usize) << 16) | ((handshake_cert_list[1] as usize) << 8) | (handshake_cert_list[2] as usize); let total_needed_in_record = 4 + 3 + cert_list_len; // handshake hdr + list len + list if payload.len() < total_needed_in_record || cert_chain_length < total_needed_in_record { // Not all chain bytes received yet => Incomplete return Err(TlsExtractError::Incomplete); } // Extract certificate_list (length prefix + concatenated certs) starting at cert_list length field let cert_list_with_len = &payload[4..total_needed_in_record]; let mut out = Vec::with_capacity(cert_list_with_len.len()); out.extend_from_slice(cert_list_with_len); Ok(out) } /// Extract only the leaf certificate (first certificate in the certificate_list) from a TLS 1.2 /// ServerHello + Certificate sequence. Returns DER bytes of the leaf certificate. 
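/// A hedged usage sketch (assumes `tls_buffer` holds the reassembled TCP payload starting
/// at the ServerHello record):
///
/// ```ignore
/// match extract_tls12_leaf_certificate(&tls_buffer) {
///     Ok(der) => println!("leaf certificate: {} DER bytes", der.len()),
///     Err(TlsExtractError::Incomplete) => { /* wait for more TCP segments */ }
///     Err(TlsExtractError::Invalid) => { /* not a TLS 1.2 ServerHello + Certificate */ }
/// }
/// ```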
pub fn extract_tls12_leaf_certificate(body: &[u8]) -> Result, TlsExtractError> { let (payload, _cert_chain_length) = advance_to_tls12_certificate_handshake(body)?; // Need handshake header + cert_list length + first cert length if payload.len() < 10 { return Err(TlsExtractError::Incomplete); } if payload[0] != 0x0b { return Err(TlsExtractError::Invalid); } // After handshake header (4 bytes) comes certificate_list length (3 bytes) let after_handshake = &payload[4..]; if after_handshake.len() < 6 { return Err(TlsExtractError::Incomplete); } // need list len (3) + first cert len (3) let leaf_len = ((after_handshake[3] as usize) << 16) | ((after_handshake[4] as usize) << 8) | (after_handshake[5] as usize); let leaf_start = 6; // start of leaf cert bytes within after_handshake let leaf_end = leaf_start + leaf_len; if after_handshake.len() < leaf_end { return Err(TlsExtractError::Incomplete); } let leaf = &after_handshake[leaf_start..leaf_end]; // Validate ASN.1 SEQUENCE start if let Ok((tag, _, _, _)) = parse_tlv(leaf) && tag == TAG_SEQUENCE { return Ok(leaf.to_vec()); } Err(TlsExtractError::Invalid) } ``` ### `scanners/mrpf_engine/src/progress.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/progress.rs` ```rust use std::sync::mpsc::Receiver; /// Handle receiving of scanner progress messages. /// /// The handler method will be called from the scanners progress thread. This trait /// will allow you to implement different handlers for the same scanner. For example, /// /// - a simple default handler that only returns the result of the scan /// - a handler that prints the progress to the CLI and returns the result. /// - a handler that integrates with a web service that can display the progress and results pub trait ProgressHandler: Send + Sync + 'static { /// The message that will be sent to the handler type Message; /// The collection of results that will be returned when the scan is done type Results; /// The handler function for incoming messages fn handler(&mut self, receiver: Receiver) -> Self::Results; } ``` ### `scanners/mrpf_engine/src/ratelimit.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/ratelimit.rs` ```rust use std::thread::sleep; use std::time::{Duration, Instant}; /// Rate limiter for controlling the rate of packet transmission. /// /// The rate limit bucket algorithm is a mechanism to control the rate of packet transmission. /// This algorithm ensures that packets are sent at a controlled rate, preventing network congestion and ensuring /// fair usage of network resources. /// /// 1. **Initialization**: /// - A token bucket is initialized with a certain number of tokens (`packets_per_interval`). /// - Each token represents permission to send one packet. /// - The bucket is refilled at regular intervals (`interval`). /// /// 2. **Packet Transmission**: /// - For each packet to be sent, the algorithm checks if there are tokens available in the bucket. /// - If tokens are available, a token is consumed, and the packet is sent. /// - If no tokens are available, the algorithm waits until the bucket is refilled. /// /// 3. **Refilling the Bucket**: /// - The bucket is refilled at a fixed interval (`interval`). /// - When the interval elapses, the bucket is refilled to its maximum capacity (`packets_per_interval`). /// #[derive(Debug, PartialEq, Clone)] pub struct RateLimiter { interval: Duration, packets_per_interval: u32, tokens: u32, last_refill: Instant, } impl RateLimiter { /// Initialize RateLimiter with packets per second. 
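    /// A usage sketch (illustrative; the packet source and send path are assumed):
    ///
    /// ```ignore
    /// let mut limiter = RateLimiter::new(10_000); // roughly 10k packets per second
    /// limiter.start();
    /// for packet in packets {
    ///     limiter.wait_for_token(); // blocks until a token is available
    ///     send_packet(&mut tx, &packet, &mut full_write_buffer_count);
    /// }
    /// ```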
/// /// The interval will be converted down to 50 milliseconds to avoid write buffer congestion. /// /// # Arguments /// /// * `packets_per_second` - The maximum number of packets allowed per interval. /// /// # Returns /// /// A new RateLimiter instance. pub fn new(packets_per_second: u32) -> RateLimiter { let interval = Duration::from_millis(50); let factor = Duration::from_secs(1).as_secs_f64() / interval.as_secs_f64(); let packets_per_interval = (packets_per_second as f64 / factor).ceil() as u32; RateLimiter { interval, packets_per_interval, tokens: packets_per_interval, last_refill: Instant::now(), } } /// Initialize RateLimiter with custom interval and packets per interval. /// /// Note that setting a long (e.g. 1 second) interval will cause the rate limiter /// to send all packets in the first few microseconds. This will likely cause /// write buffer congestion. /// /// # Arguments /// /// * `interval` - The interval duration for rate limiting. /// * `packets_per_interval` - The maximum number of packets allowed per interval. /// /// # Returns /// A new RateLimiter instance. pub fn new_with_interval(interval: Duration, packets_per_interval: u32) -> RateLimiter { RateLimiter { interval, packets_per_interval, tokens: packets_per_interval, last_refill: Instant::now(), } } /// Get the rate limits per second pub fn get_rate_limit_per_second(&self) -> u32 { let interval_secs = self.interval.as_secs_f64(); (self.packets_per_interval as f64 / interval_secs).round() as u32 } /// Start the rate limiter. /// /// This will reset the the tokens and last refill time if the interval has passed. pub fn start(&mut self) { self.tokens = self.packets_per_interval; self.last_refill = Instant::now(); } /// Wait until we can send a packet. /// /// If we have tokens available, we consume a token and return true. /// If the rate limit bucket is empty, we refill the bucket and sleep. pub fn wait_for_token(&mut self) { loop { if self.tokens > 0 { // Consume a token and signal we're ready to send self.tokens -= 1; break; } else { // Rate limit bucket is empty, refill and sleep let elapsed = self.last_refill.elapsed(); if elapsed >= self.interval { self.tokens = self.packets_per_interval; self.last_refill = Instant::now(); } else { // println!( // "TODO: should change the interval to something a lot less than a second to avoid sending all packets in the first few microseconds. Empty bucket, refilling and sleeping {:?}", // self.interval - elapsed // ); sleep(self.interval - elapsed); } } } } } impl Default for RateLimiter { /// Defaults to 10,000 packets per second. /// /// Note the actual settings are 500 packets per 50 milliseconds to avoid write buffer congestion. fn default() -> Self { RateLimiter { interval: Duration::from_millis(50), packets_per_interval: 500, tokens: 500, last_refill: Instant::now(), } } } // The base statement we assume is that a single ip can process no more than 10 packet per second. const MAX_PACKETS_PER_SECOND_PER_DESTINATION: u32 = 10; const MAX_PACKETS_PER_DESTINATION_INTERVAL: Duration = Duration::from_secs(1); /// Calculates the packet rate limit and interval duration based on the number of ports and IP addresses. /// /// TODO: This isn't working yet, need to look at the code and think about it more. /// /// 1. max rate limit is for my own machine (e.g. regardless of network, i can send x amount of packets every x milliseconds) /// 2. If the size of the total ports * addresses to scan is 'large enough' I should use the max rate limit for my own machine /// 3. 
If the size of the total ports * addresses to scan is 'too large' for the destination hosts to handle, I should scale down the x amount of packets downward. /// /// TODO: There's actually another criteria that i haven't considered. Rate limits are mainly to avoid overwhelming a single server OR network. /// A single server can progress a scan of one or two ports without any problems, it's when we're scanning a lot of ports that we run into issues. /// the ports * targets therefore is very naive. /// /// # Arguments /// * `num_ports` - The number of ports to scan. /// * `num_ips` - The number of IP addresses to scan. /// * `max_rate_limit_per_interval` - The maximum rate limit for your machine (packets per interval). /// * `interval` - The fixed interval duration for rate limiting. Since our scanner is pushing out packets as fast as it can, make sure /// to not make this interval too long (e.g. 10 to 100 milli seconds could be good?) /// /// # Returns /// A tuple containing the packet rate limit (packets per interval) and the interval duration. #[expect(dead_code)] pub(crate) fn calculate_rate_limit( num_ports: u16, num_ips: u32, max_rate_limit_per_interval: u32, interval: Duration, ) -> u32 { // The total number of packets to send let total_packets = num_ports as u32 * num_ips; // Calculate the rate limit using an exponential function let rate_limit = if total_packets <= MAX_PACKETS_PER_SECOND_PER_DESTINATION { 1 } else { ((total_packets as f64).exp()) as u32 }; println!("{}", rate_limit); // Adjust the rate limit based on the given interval compared to the base interval let adjusted_rate_limit = (rate_limit as f64 * (interval.as_secs_f64() / MAX_PACKETS_PER_DESTINATION_INTERVAL.as_secs_f64())) as u32; let adjusted_max_rate_limit = (max_rate_limit_per_interval as f64 * (interval.as_secs_f64() / MAX_PACKETS_PER_DESTINATION_INTERVAL.as_secs_f64())) as u32; // Ensure the rate limit does not exceed the maximum rate limit or is 0 adjusted_rate_limit.min(adjusted_max_rate_limit).max(1) } #[cfg(test)] mod tests { use super::*; #[test] fn test_rate_limiter_new() { let rate_limiter = RateLimiter::new(1000); assert_eq!(rate_limiter.interval, Duration::from_millis(50)); assert_eq!(rate_limiter.packets_per_interval, 50); assert_eq!(rate_limiter.tokens, 50); } #[test] fn test_rate_limiter_new_with_interval() { let interval = Duration::from_secs(1); let rate_limiter = RateLimiter::new_with_interval(interval, 10); assert_eq!(rate_limiter.interval, interval); assert_eq!(rate_limiter.packets_per_interval, 10); assert_eq!(rate_limiter.tokens, 10); } #[test] fn test_rate_limiter_get_rate_limit_per_second() { let rate_limiter = RateLimiter::new(1_000); assert_eq!(rate_limiter.get_rate_limit_per_second(), 1_000); } #[test] fn test_rate_limiter_start() { let mut rate_limiter = RateLimiter::new(1000); rate_limiter.start(); assert_eq!(rate_limiter.tokens, rate_limiter.packets_per_interval); } #[test] fn test_rate_limiter_wait_for_token() { let mut rate_limiter = RateLimiter::new(1); rate_limiter.start(); rate_limiter.wait_for_token(); assert_eq!(rate_limiter.tokens, 0); } #[test] fn test_rate_limiter_default() { let rate_limiter = RateLimiter::default(); assert_eq!(rate_limiter.interval, Duration::from_millis(50)); assert_eq!(rate_limiter.packets_per_interval, 500); assert_eq!(rate_limiter.tokens, 500); assert_eq!(rate_limiter.get_rate_limit_per_second(), 10_000); } } ``` ### `scanners/mrpf_engine/src/utils.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_engine/src/utils.rs` ```rust use pnet::datalink::{self, 
NetworkInterface}; use serde::Serialize; use std::fs::File; use std::net::{IpAddr, Ipv4Addr}; use tracing::debug; /// Retrieve the network interface to use for scanning /// /// If an interface name is provided, the function will return the interface with the specified name. /// Otherwise, it will return the first non-loopback interface that is up. /// /// # Arguments /// /// * `interface_name` - An optional string slice containing the name of the interface to use /// /// # Returns /// /// A `NetworkInterface` object representing the network interface to use /// /// # Example /// /// ```ignore /// use mrpf::utils::get_interface; /// /// let interface = get_interface(None); /// debug!("{:?}", interface); /// // { name: "...", index: ... } /// ``` pub fn get_interface(interface_name: Option<&str>) -> NetworkInterface { let interfaces = datalink::interfaces(); // TMP Print all interfaces for iface in &interfaces { debug!( "Interface: {} - MAC: {:?} - IPs: {:?} - is_up: {} - is_loopback: {}", iface.name, iface.mac, iface .ips .iter() .map(|ip| ip.ip()) .collect::>(), iface.is_up(), iface.is_loopback() ); } match interface_name { Some(interface_name) => interfaces .into_iter() .find(|iface| iface.name == interface_name) .expect("No network interface found with given name"), // Complaining about function call? // .expect(&format!( // "No network interface found with the specified name '{:}'", // interface_name // )), None => interfaces .into_iter() .find(|iface| iface.is_up() && !iface.is_loopback()) .expect("No suitable network interface found"), } } /// Retrieve the first IPv4 address from a given network interface /// /// # Arguments /// /// * `interface` - A reference to a `NetworkInterface` object /// /// # Returns /// /// An `Ipv4Addr` object representing the first IPv4 address found on the interface. pub fn get_interface_ip(interface: &NetworkInterface) -> Ipv4Addr { interface .ips .iter() .find(|network| network.is_ipv4()) .map(|network| match network.ip() { IpAddr::V4(ipv4) => ipv4, _ => unreachable!(), }) .expect("No IPv4 address found") } /// Very simple function to write a JSON file because I'm lazy /// /// # Arguments /// /// * `filename` - A string slice containing the name of the file to write /// * `data` - The data to write to json file. This can be anything /// that serde is able to serialize like a HashMap, Vec, scalars, etc pub fn write_json_file(filename: &str, data: &T) { let file = File::create(filename).expect("Unable to create file"); serde_json::to_writer_pretty(file, data).expect("Unable to write data"); } ``` ### `scanners/mrpf_http1_scanner/FCK.md` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/FCK.md` ```markdown # pfff ok my http1 scanner inputs is interesting What do I want to scan? 1. A list of ip addresses with unique SNI's for each ip and all it's ports 2. A list of ip addresses with the same (or None) SNI for each ip and all it's ports 3. A single IP address with a single SNI for all it's ports 4. Vhost scanning for a single IP address, with multiple SNI's and checking the response (would our tls scanner be better for this? That would mean we require the host to return a unique certificate) ``` ### `scanners/mrpf_http1_scanner/Problem.md` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/Problem.md` ```markdown # Problem The send/receive http doesn't work well with templates. 
Since the templates can generate quite a few different combinations which are not known in advance (unless you already calculate everything before starting the scan which is cpu and mem intensive), we don't know when to stop sending, or how to glue multiple generated bodies for the same host together. ## Some thoughts One thing thats definitely needed: - Our connection table should contain the template iterator. It can then iterate on them until everything is generated. - TLS can be re-used but ONLY when HTTP does connection: keep-alive. Most flexible way would be to somehow trigger the receive function again for the next template iteration. It should then check if there is still a connection open and if so, re-use it. If not, it should open a new one. This process repeats until the template iterator is done. - This might mean that for our HTTP scanner, the transmit thread should not be in control to determine when we're done handling things, but it should be the receive thread. We can then loop our send function until the receive thread signals its done Perhaps flipping this send and receive on it's head will make the code more robust for other scanner implementations as well. When the transmit function is only just sending packets it receives from the receiver, we don't need to pass any iterator logic to the transmit function. That means the function can be re-used across scanners. It will also mean it will be easier to create a pause/stop function in the scanner, as the receive function can better coordinate this. The biggest challenge though with this is that the recv function is blocking until it sees a new packet.... ## More thoughts Turning it around won't fully work. - Update the connection struct to store ip/port, templateiterator, current_generated_template_to_send. - The transmit function should iterate through all ip/ports and send initial tcp_syn packets. NOTE: We will use different src ports so we can parallelize multiple connections going to the same host. - When it's done sending all initial syn packets, it should continue sending all receive packets until it receives a stop message from the receive thread - The receive thread will use the various TLS states to determine if it needs to re-establish a new tcp connection (by sending a tcp syn packet to the send queue), or if it can leverage the existing connection, generate the next current_generated_template_to_send, and send it over the existing TLS connection. NOTE: THIS DOES MEAN THAT GENERATED TEMPLATES TO THE SAME HOST WILL BE PROCESSED IN SEQUENCE.. This is defeating somewhat of the purpose of my utra-fast scanner :( ## Again rethinking this The last statement in previous section highlights again incorrect thinking. Perhaps the key is to separate the template generation from the ip/port/sni tuples? This is going to be difficult as we want to be able to use the ip/port/sni's as variables in the template themselves. Do we need shared connection state between the transmit and receive threads to coordinate this all? I GUESS I NEED TO ANSWER THE FOLLOWING QUESTION: - Do I want to generate a new connection for each template, OR do I want to re-use connections for the same host? New connection per generated template: - Biggest potential parallelism - No potential for re-using connections, so in the end will have caused more connections to be made than necessary Ok again an idea that is somewhere in the middle. - Maybe we generate a new template every time a new packet comes in. 
This will allow us to build up a pool of open TLS connections to the same host, whilst still trying to optimize for parallelism. I think it's just too complex and overkill for now to support TLS connection re-use. Most robust option might be now to share the connection table across threads. The transmit therad will make the connection table entries by syn cookie. It will also generate the template to send. For each new template, it will iterate through the port and ip address again, ensuring it will naturally spread out the connections over the hosts and have unique entries per template. This can clean up some of the code in the receive thread as well. Looking at this from a locking perspective, sharing the connection table is problematic as the transmit and receive threads will be competing for the same table lock. Perhaps we should establish a channel from transmit to receive to send over the connection table. HOWEVER, could it be we're already receiving packets before we've processed the new connections? Maybe not if we handle new connections directly after the recv.next call before doing anything else? The connection table entries will need a connection timeout so we can remove entries that are not responding. ## Detecting when to stop I'm again approaching it wrong by moving it to the receive function. The key is to have BOTH transmit and receive threads signal they are finished. Transmit will signal it's done with whatever initial TCP syns, the receive thread will signal it's done with whatever connection table it holds. When both are done, we can stop the scanner. Note that the receive thread often will add new packets to the transmit thread, so the transmit thread needs to both track initial syn + incoming packets. Not 100% sure how we can sync these up: if transmit.done_sending_syn and transmit.queue.isEmpty. Will the isEmpty check possibly be true before we've processed all incoming packets from recv queue? Yes that is a possibility. This means the recv process also need to be able to signal to the transmit queue that it's done processing anything. This will do away with the annoying 5 second wait timer we have now, whihc makes single packet sending a lot slower than needed. ``` ### `scanners/mrpf_http1_scanner/src/Howto.md` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/Howto.md` ```markdown # How to implement my HTTP 1/x with TLS server in Rust? The TcpConnectionTable is a good way to manage TCP connections. Then I only need the first reassembled packet from the server in my TLS Scanner before closing the connection. So effectively, if we think about the encapsulation: TLS Scanner: Ethernet -> IP -> TCP -> Payload we're after For a HTTP scanner plain text this would be: Ethernet -> IP -> TCP -> Payload we're after However, for any protocol thats encapsulated in TLS, it will be: Ethernet -> IP -> TCP -> TLS -> Payload we're after So, to solve my Http1Scanner, Perhaps I need to work with a TlsConnectionTable instead of the TcpConnectionTable inside my receive function. The TlsConnectionTable will internally have a TcpConnectionTable instance. ## update 2 Ok, started to fill in some of the blanks. There are cases where I'm extracting the connection in the receive function more than once. This is less optimal. What if we make the TlsConnectionTable a trait with an associated type for the connection. Then I can create connection types for my different implementations. 
These connections will then expose the functions I'm currently handling through the connection, like reassemble, etc. ```rust pub trait TcpConnectionTable { type Connection; fn new() -> Self; fn insert(&mut self, ip_address: Ipv4Addr, port: u16, connection: Self::Connection); fn get(&self, ip_address: Ipv4Addr, port: u16) -> Option<&Self::Connection>; fn remove(&mut self, ip_address: Ipv4Addr, port: u16); } ``` ## Update 3 Considering using rustls here for processing the TLS packages. They offer an unbuffered example that *should* allow me to bring my own network stack. Here's the client example, it seems the code has to call process_tls_packages, which will change the internal state to connected eventually. https://github.com/rustls/rustls/blob/main/examples/src/bin/unbuffered-client.rs Have a read over this issue as well: https://github.com/rustls/rustls/issues/1958 ``` ### `scanners/mrpf_http1_scanner/src/connections.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/connections.rs` ```rust use super::tls::get_tls_client_config; use mrpf_engine::{Connection, TcpReassemblyBuffer}; use rustls::client::UnbufferedClientConnection; use rustls::unbuffered::{AppDataRecord, ConnectionState, UnbufferedStatus}; use std::sync::Arc; #[derive(PartialEq, Clone)] pub enum TlsState { /// Do nothing DoNothing, /// Send packet Send(usize), /// TLS Client should close Finish, /// Connection Error Error, } /// A TLS Connection pub struct TlsConnection { /// Holds the result buffer: TcpReassemblyBuffer, /// The next sequence number our connection will use for sending packets to the remote host next_seq: u32, /// The next ack number our connection should confirm next_ack: u32, /// Rustls TLS client connection conn: UnbufferedClientConnection, // Bytes that will contain incoming TLS data // // This will be filled in by us when we receive new packets from the network. incoming_tls: Vec, /// The unencrypted payload to send payload: Vec, /// Set when the unencrypted payload has been sent payload_sent: bool, /// The SNI /// /// TODO: Getters pub sni: String, /// The HTTP response received /// TODO: Getters pub response: Vec, // Timeout time // timeout: Instant, } // TLS session is establised impl Connection for TlsConnection { /// Add data to an existing connection /// /// The given data will be added to the connection buffer. /// /// # Arguments /// /// * `data` - The data to add to the connection /// * `seq` - The sequence number of the data /// /// # Returns /// /// * `bool` - True if the data was added to the connection, false if the connection does not exist fn add_segment(&mut self, data: &[u8], seq: u32) { self.buffer.insert_segment(seq, data.to_vec()); // if self.is_established() { // // TODO: Decrypt the data using the key // self.buffer.insert_segment(seq, data.to_vec()); // } else { // self.buffer.insert_segment(seq, data.to_vec()); // }; } /// Get the reassembled data for a connection /// /// After processing the reassembled data, clear_buffer should be called before /// processing further packets. /// /// # Arguments /// * `ip_address` - The IP address of the connection /// * `port` - The port of the connection /// /// # Returns /// * `Option>` - The decrypted buffer for the connection, or None if the TCP segment is not yet complete /// or the connection does not exist. 
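// A commented-out sketch (hypothetical values) of how a caller is expected to
// drive this `Connection` implementation: create the connection, feed TCP
// segments as they arrive (possibly out of order), then pull the reassembled
// bytes. The payload, SNI and sequence numbers below are illustrative only;
// the exact reassembly semantics live in `mrpf_engine::TcpReassemblyBuffer`.
//
//   let mut conn = TlsConnection::new(
//       "example.com".to_string(), // SNI
//       b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n".to_vec(),
//       1_000, // next_seq, our ISN + 1 (the scanner uses syn_cookie + 1)
//       1_000, // next_ack, updated once the SYN/ACK arrives
//   );
//   conn.add_segment(b"hello", 1_000);
//   conn.add_segment(b"world", 1_005);
//   // Returns Some(bytes) once the buffered segments form a contiguous run.
//   if let Some(bytes) = conn.get_reassembled_buffer() { /* process bytes */ }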
fn get_reassembled_buffer(&mut self) -> Option> { self.buffer.reassemble() } /// Get the next sequence number fn next_seq(&self) -> u32 { self.next_seq } } impl TlsConnection { pub fn payload(&self) -> Vec { self.payload.clone() } pub fn set_next_seq(&mut self, seq: u32) { self.next_seq = seq; } /// Get the next ack number pub fn next_ack(&self) -> u32 { self.next_ack } /// Set the next ack number pub fn set_next_ack(&mut self, ack: u32) { self.next_ack = ack; } // Check if connection timeout is exceeded // TODO: Move this to the Connection trait and integrate it into // The TcpConnectionTable so we can clear out all connections that // have timed out. // pub fn connection_timed_out(&self, current_time: Instant) -> bool { // current_time > self.timeout // } /// Create a new TlsConnection /// /// # Arguments /// /// * `server_name` - The server name of the connection /// * `next_seq` - The next sequence number our connection will use for sending packets to the remote host /// * `next_ack` - The next ack number our connection should confirm pub fn new(server_name: String, payload: Vec, next_seq: u32, next_ack: u32) -> Self { let buffer = TcpReassemblyBuffer::new(next_seq); let config = get_tls_client_config(); let conn = UnbufferedClientConnection::new( Arc::clone(&config), server_name.clone().try_into().unwrap(), ) .unwrap(); Self { buffer, next_seq, next_ack, conn, // TODO: What is average size of a incoming TLS Server? I've seen 2 * mtu(1460) + a bit more incoming_tls: Vec::with_capacity(1460 * 2), payload, payload_sent: false, sni: server_name, response: Vec::with_capacity(1460), // timeout: Instant::now() + std::time::Duration::from_secs(10), } } /// Process TLS Messages /// /// # Arguments /// /// * `incoming_payload` - The incoming payload /// * `outgoing_tls` - The outgoing TLS payload /// /// # Returns /// /// * `TlsState` - The state of the TLS connection pub fn process(&mut self, incoming_payload: &[u8], outgoing_tls: &mut [u8]) -> TlsState { // TODO: We are not properly sorting the incoming packets here, we should use // the TcpReassemblyBuffer to do this. This will mean we need to handle the // discard better OR we need to reassemble, send and then clear the buffer // again? self.incoming_tls.extend_from_slice(incoming_payload); let mut outgoing_bytes: Option = None; let UnbufferedStatus { mut discard, state } = self.conn.process_tls_records(&mut self.incoming_tls); if let Err(error) = state { println!("Error: {:?}", error); return TlsState::Error; } match state.unwrap() { ConnectionState::EncodeTlsData(mut state) => { // A Handshake record is ready for encoding match state.encode(outgoing_tls) { Ok(handshake_len) => { outgoing_bytes = Some(handshake_len); } Err(_) => { panic!("Insufficient size?"); } } } ConnectionState::ReadEarlyData(mut _state) => { // println!("Early data"); todo!("Handle early data"); } ConnectionState::ReadTraffic(mut state) => { // println!("Read traffic"); while let Some(res) = state.next_record() { let AppDataRecord { discard: new_discard, payload, } = res.unwrap(); discard += new_discard; self.response.extend_from_slice(payload); } // This will send an ACK for the received data // TODO: This is where we would want to look at the window size to avoid sending // an ACK for each received segment. 
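// A possible refinement of the TODO above (not implemented here): count the
// bytes received since the last ACK and only emit one (i.e. set
// `outgoing_bytes = Some(0)`) once that count approaches the peer's advertised
// receive window or a short timer fires, similar to the kernel's delayed ACKs.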
outgoing_bytes = Some(0); } ConnectionState::BlockedHandshake => { // println!("Blocked handshake"); } ConnectionState::WriteTraffic(mut encrypter) => { // println!("Write traffic"); // The TLS connection allows us to write data // It could be we still need to send our payload, we've received a partial http response // or we've received a full http response but we use connection: keep-alive to allow // us to send more requests. if !self.payload_sent { let encrypted_payload_len = encrypter .encrypt(&self.payload, outgoing_tls) .expect("encrypted request does not fit in `outgoing_tls`"); outgoing_bytes = Some(encrypted_payload_len); self.payload_sent = true; } else { // Send Acknowledgements? I don't think thats needed but... // TODO: Close the connection here for now as we don't want to re-use connections with keep-alive at the moment. // The problem is that we could still have an incompleted http response so we should only finish when // we know the http response is completed. For now we'll just ack the packet here. This does mean we // get stuck when we use connection: keep-alive in our http request. // return TlsState::Finish; outgoing_bytes = Some(0); } } ConnectionState::Closed => { // Here the remote host signals that it is closing the connection // When using connection: close, we know for sure that they've sent what they wanted to send. // However, with connection: keep-alive, we would need to parse the HTTP packet to determine // it fully received the response. So best thing is to refactor my code to send all data during // the write traffic state. In the ReadTraffic state we should be parsing the HTTP // message and determine if we have the full response. If so, we update the connnection for this // ip/port to generate the new template to send. If we get a Finish state, we should create a // new TCP Syn packet for our sender. If we get a WriteTraffic state, we can start sending the // next HTTP request. println!("Finish"); return TlsState::Finish; } ConnectionState::TransmitTlsData(mut state) => { // println!("Transmit TLS data"); if let Some(mut _may_encrypt_early_data) = state.may_encrypt_early_data() { // println!("May encrypt early data"); // may_encrypt_early_data.encrypt(b"hello", outgoing_tls); todo!("Handle early data"); // sent_early_data = true; } if let Some(mut encrypter) = state.may_encrypt_app_data() { // println!("May encrypt app data"); if !self.payload_sent { // let request = "GET /wp-content/uploads/2008/09/cropped-heritage-535.jpg HTTP/1.1\r\nHost: www.mortimer.nl\r\nConnection: close\r\nAccept-Encoding: identity\r\n\r\n".to_string(); // let request = "GET / HTTP/1.1\r\nHost: www.mortimer.nl\r\nConnection: close\r\nAccept-Encoding: identity\r\n\r\n".to_string(); let encrypted_payload_len = encrypter .encrypt(&self.payload, outgoing_tls) .expect("encrypted request does not fit in `outgoing_tls`"); outgoing_bytes = Some(encrypted_payload_len); self.payload_sent = true; } } // Call state.done when we've sent our data. 
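// As understood from the rustls unbuffered API: `done()` only marks this
// TransmitTlsData state as handled so that the next `process_tls_records()`
// call can advance the handshake; the encrypted bytes themselves still leave
// via `outgoing_tls` and the `TlsState::Send(len)` returned at the end of
// this function.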
state.done(); } _ => unreachable!(), } self.incoming_tls.drain(..discard); if let Some(outgoing_bytes) = outgoing_bytes { TlsState::Send(outgoing_bytes) } else { TlsState::DoNothing } } } ``` ### `scanners/mrpf_http1_scanner/src/lib.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/lib.rs` ```rust pub(crate) mod connections; pub mod progress; mod receive; pub mod scanner; pub(crate) mod template; pub(crate) mod tls; mod transmit; pub use crate::progress::Http1Message; pub use crate::scanner::Http1Scanner; pub use crate::template::{ Template, TemplateIterator, TemplateValue, TemplateVariable, TemplateVariables, }; mod models; pub use models::{Http1Request, Http1Response}; ``` ### `scanners/mrpf_http1_scanner/src/models/http1request.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/models/http1request.rs` ```rust use serde::{Deserialize, Serialize}; use std::net::Ipv4Addr; #[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "snake_case")] pub struct Http1Request { /// The IP Address to connect to ip: Ipv4Addr, /// The TCP port to connect to port: u16, /// The body of the request body: Vec, /// Whether the request is over TLS is_tls: bool, /// The SNI to use for the request #[serde(default)] sni: Option, } impl Clone for Http1Request { /// Ensure our parsed Request is not cloned fn clone(&self) -> Self { Http1Request { ip: self.ip, body: self.body.clone(), port: self.port, is_tls: self.is_tls, sni: self.sni.clone(), } } } impl Http1Request { /// Create a new HTTP (Non-TLS) request pub fn new_http(ip: Ipv4Addr, body: Vec, port: u16) -> Self { Http1Request { ip, body, port, is_tls: false, sni: None, } } /// Create a new HTTPS request pub fn new_https(ip: Ipv4Addr, body: Vec, port: u16, sni: Option) -> Self { Http1Request { ip, body, port, is_tls: true, sni, } } pub fn ip(&self) -> &Ipv4Addr { &self.ip } pub fn port(&self) -> u16 { self.port } pub fn body(&self) -> &Vec { &self.body } pub fn is_tls(&self) -> bool { self.is_tls } pub fn sni(&self) -> Option<&String> { self.sni.as_ref() } // TODO: Fix this, have to work with lifetimes to get it to work properly. // also might need separate headers param // pub fn parsed(&mut self) -> Option<&httparse::Request<'static, 'static>> { // if self.parsed.is_none() { // let mut headers = [httparse::EMPTY_HEADER; 64]; // let mut req = httparse::Request::new(&mut headers); // if req.parse(&self.body).is_ok() { // self.parsed = Some(req); // } // } // self.parsed.as_ref() // } // pub fn path(&self) -> Option<&str> { // self.parsed().and_then(|r| r.path) // } } ``` ### `scanners/mrpf_http1_scanner/src/models/http1response.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/models/http1response.rs` ```rust use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] pub enum Content { Text(String), Binary(Vec), } #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] pub struct Http1Response { /// The HTTP Status Code status_code: u16, /// The parsed HTTP Headers /// Should I instead create a Header type for easier retrieval of the values? look at HTTP crate perhaps as it has a bunch of useful types headers: HashMap>, /// The raw bytes of the returned status line and headers. Useful if we want to remain the order of the headers /// in our macOS scanner app. 
raw_headers: Vec, /// The content of the response content: Content, /// The hash of the content content_hash: String, /// The length of the content content_length: usize, } impl Http1Response { /// Initialize the HTTP1Response from given bytes pub fn from_bytes(bytes: Vec) -> Result> { let mut headers = [httparse::EMPTY_HEADER; 64]; let mut response = httparse::Response::new(&mut headers); let parse_status = response.parse(&bytes)?; if parse_status.is_partial() { return Err("Incomplete HTTP response".into()); } let status_code = response.code.ok_or("Missing status code")?; let mut headers = HashMap::new(); for h in response.headers.iter() { let entry = headers.entry(h.name.to_string()).or_insert_with(Vec::new); entry.push(String::from_utf8(h.value.to_vec()).unwrap_or_default()); } let body_start = parse_status.unwrap(); // TODO: I think I could split the headers and body here for a slight memory optimization? let raw_headers = bytes[..body_start].to_vec(); let body = &bytes[body_start..]; // TODO: This is a bit tricky as binary data can be compatible with // utf8 encoding. We should check the content-type header // to determine if the content is text or binary if we really want to // be sure. let content = if let Ok(text) = String::from_utf8(body.to_vec()) { Content::Text(text) } else { Content::Binary(body.to_vec()) }; let content_hash = Http1Response::calculate_content_hash(body); let content_length = body.len(); Ok(Http1Response { status_code, headers, raw_headers, content, content_hash, content_length, }) } /// Returns the amount of words in the response content if the /// content is text pub fn word_count(&self) -> usize { if let Content::Text(ref text) = self.content { text.split_whitespace().count() } else { 0 } } /// Returns the amount of lines in the response content if the /// content is text pub fn line_count(&self) -> usize { if let Content::Text(ref text) = self.content { text.lines().count() } else { 0 } } /// Return the status code pub fn status_code(&self) -> u16 { self.status_code } /// Return reference to the content pub fn content(&self) -> &Content { &self.content } /// Calculate the hash of the content /// /// This allows us to quickly compare the response content with other /// responses to see if they are the same. This does NOT /// include the headers in the hash. fn calculate_content_hash(_content: &[u8]) -> String { "TODO".to_string() // What hash function do i want? 
// - murmur3 hash is used for favicon on shodan and censys // - We use gxhash for our own syn cookie hash function // md5, sha1, sha256, sha512 are all options // use sha2::{Digest, Sha256}; // let mut hasher = Sha256::new(); // hasher.update(content); // format!("{:x}", hasher.finalize()) } } #[cfg(test)] #[allow(clippy::unwrap_used)] #[allow(clippy::panic)] mod tests { use super::*; #[test] fn test_from_bytes_text() { let response_bytes = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, world!".to_vec(); let response = Http1Response::from_bytes(response_bytes).unwrap(); assert_eq!(response.status_code(), 200); assert_eq!( response.headers.get("Content-Type").unwrap(), &["text/plain"] ); if let Content::Text(text) = response.content { assert_eq!(text, "Hello, world!"); } else { panic!("Expected text content"); } assert_eq!(response.content_length, 13); } // #[test] // fn test_from_bytes_binary() { // let response_bytes = // b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n\x01\x02\x03\x04" // .to_vec(); // let response = Http1Response::from_bytes(response_bytes).unwrap(); // assert_eq!(response.status_code, 200); // assert_eq!( // response.headers.get("Content-Type").unwrap(), // "application/octet-stream" // ); // if let Content::Binary(data) = response.content { // assert_eq!(data, vec![1, 2, 3, 4]); // } else { // panic!("Expected binary content"); // } // assert_eq!(response.content_length, 4); // } #[test] fn test_word_count() { let response = Http1Response { status_code: 200, headers: HashMap::new(), raw_headers: Vec::new(), content: Content::Text("Hello world".to_string()), content_hash: String::new(), content_length: 11, }; assert_eq!(response.word_count(), 2); } #[test] fn test_line_count() { let response = Http1Response { status_code: 200, headers: HashMap::new(), raw_headers: Vec::new(), content: Content::Text("Hello\nworld".to_string()), content_hash: String::new(), content_length: 11, }; assert_eq!(response.line_count(), 2); } } ``` ### `scanners/mrpf_http1_scanner/src/models/mod.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/models/mod.rs` ```rust mod http1request; mod http1response; pub use http1request::Http1Request; pub use http1response::Http1Response; ``` ### `scanners/mrpf_http1_scanner/src/progress.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/progress.rs` ```rust use crate::{Http1Request, Http1Response}; use mrpf_engine::ProgressHandler; use std::sync::mpsc::Receiver; /// Progress Messages sent during a HTTPS scan. pub enum Http1Message { // A list of names and the Ip and port they were found on Result((Http1Request, Http1Response)), /// Receive statistics RxStats { packet_count: u32, full_read_buffer_count: u8, }, /// Transmit statistics TxStats { syn_packet_count: u32, packet_count: u32, full_write_buffer_count: u8, }, /// The scan is finished Finished, } /// Headless ProgressHandler for the HTTPS Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. /// No other progress messages will be parsed since we're running in a headless manner. 
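// A commented-out sketch of an alternative handler (hypothetical name
// `VerboseHandler`) that also surfaces the Tx/Rx statistics instead of
// discarding them; it only relies on the trait shape implemented below.
//
//   #[derive(Debug)]
//   struct VerboseHandler;
//
//   impl ProgressHandler for VerboseHandler {
//       type Message = Http1Message;
//       type Results = Vec<(Http1Request, Http1Response)>;
//
//       fn handler(&mut self, receiver: Receiver<Self::Message>) -> Self::Results {
//           let mut results = Vec::new();
//           while let Ok(message) = receiver.recv() {
//               match message {
//                   Http1Message::Result(pair) => results.push(pair),
//                   Http1Message::TxStats { packet_count, .. } => {
//                       eprintln!("tx packets/s: {packet_count}");
//                   }
//                   Http1Message::RxStats { packet_count, .. } => {
//                       eprintln!("rx packets/s: {packet_count}");
//                   }
//                   Http1Message::Finished => break,
//               }
//           }
//           results
//       }
//   }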
#[derive(Debug)] pub struct Http1ProgressHandler; impl ProgressHandler for Http1ProgressHandler { type Message = Http1Message; type Results = Vec<(Http1Request, Http1Response)>; fn handler(&mut self, receiver: Receiver) -> Self::Results { let mut results = Vec::with_capacity(16); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { Http1Message::Result(result) => { results.push(result); } Http1Message::Finished => break, _ => {} } } results } } ``` ### `scanners/mrpf_http1_scanner/src/receive.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/receive.rs` ```rust use super::connections::{TlsConnection, TlsState}; use super::progress::Http1Message; use crate::{Http1Request, Http1Response}; use mrpf_engine::network::arp::handle_arp_requests; use mrpf_engine::network::ip::generate_flow_cookie_v4; use mrpf_engine::network::tcp::{create_tcp_ack, create_tcp_ack_with_data, create_tcp_rst}; use mrpf_engine::{Connection, TcpConnectionTable}; use mrpf_engine::{ScannerConfig, unwrap_or_continue}; use pnet::datalink::DataLinkReceiver; use pnet::packet::Packet; use pnet::packet::ethernet::{EtherTypes, EthernetPacket}; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::tcp::TcpFlags::{ACK, SYN}; use pnet::packet::tcp::TcpPacket; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{Receiver, Sender}; use std::time::Duration; use tracing::debug; /// The main receive handler for the HTTPS scanner engine. /// /// The function runs in a loop and receives packets from the network interface. /// /// It uses the SYN cookie algorithm to determine if the received packet belongs /// to the scanner. /// /// # Arguments /// /// * `rx` - The data link receiver. /// * `packet_tx` - The sender to transmit packets. /// * `progress_tx` - The sender to transmit progress updates. /// * `config` - The scanner configuration. /// * `https_targets` - The HTTPS targets. pub(crate) fn receive_handler( mut rx: Box, packet_tx: Sender>, progress_tx: Sender, new_connections_rx: Receiver, config: ScannerConfig, stop_flag: Arc, ) { let mut packet_count = 0; let full_read_buffer_count = 0; let progress_interval: Duration = Duration::from_secs(1); // We explicitly set the last progress to half the interval to try and // be out of sync with the transmit thread. This should spread the load // of the progress updates a bit more. let mut last_progress = std::time::Instant::now() - (progress_interval / 2); // Connection Table let mut connection_table = TcpConnectionTable::::new(); // Bytes that will contain outgoing TLS encrypted data // // This will be filled in by the various stages of the RustTLS state machine. // If there is any pending packet that need to be sent from our side, it will be // stored here and our process function will return the length of the packet. let mut outgoing_tls: [u8; 1460] = [0; 1460]; while !stop_flag.load(Ordering::Relaxed) { match rx.next() { Ok(packet) => { // Add any new connections to the connection table while let Ok(new_connection) = new_connections_rx.try_recv() { connection_table.insert(new_connection.next_seq() - 1, new_connection); } // Determine if the received packet belongs to our scanner // TODO: I want to see if I can do ARP spoofing so I can filter packets to the spoofed // ARP address early here to determine it comes from the scanner. 
This will // allow us to skip the TCP cookie generation for a proper comparison. let ethernet_packet = unwrap_or_continue!(EthernetPacket::new(packet)); let recv_eth_src = ethernet_packet.get_source(); if handle_arp_requests(ðernet_packet, &config, &packet_tx) { // ARP request handled continue; } if ethernet_packet.get_destination() != config.src_mac || ethernet_packet.get_ethertype() != EtherTypes::Ipv4 { // Packet was not destined for us continue; } let ipv4_packet = unwrap_or_continue!(Ipv4Packet::new(ethernet_packet.payload())); let recv_src_ip = ipv4_packet.get_source(); if ipv4_packet.get_destination() != config.src_ip { // This is especially useful if we have a dedicated (e.g. spoofed) IP for the scanner continue; } if ipv4_packet.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { // debug!("Received non-TCP packet"); continue; } let tcp_packet = unwrap_or_continue!(TcpPacket::new(ipv4_packet.payload())); let recv_src_port = tcp_packet.get_source(); let recv_dst_port = tcp_packet.get_destination(); let tcp_flags = tcp_packet.get_flags(); let recv_seq = tcp_packet.get_sequence(); let recv_ack = tcp_packet.get_acknowledgement(); // Calculate the syn cookie to see if we have a new incoming connection // that we need to acknowledge. let syn_cookie = generate_flow_cookie_v4( &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, config.entropy, ); let connection = unwrap_or_continue!(connection_table.get(syn_cookie)); if recv_ack == (syn_cookie + 1) && tcp_flags == (SYN | ACK) { // Send ACK packet to complete the handshake let tcp_ack = unwrap_or_continue!(create_tcp_ack( &config.src_mac, &recv_eth_src, &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, recv_ack, recv_seq + 1, )); packet_tx .send(tcp_ack.to_vec()) .expect("Failed to send TCP ACK packet"); connection.set_next_seq(recv_ack); connection.set_next_ack(recv_seq + 1); } else { connection.set_next_ack(recv_seq + tcp_packet.payload().len() as u32); } match connection.process(tcp_packet.payload(), &mut outgoing_tls) { TlsState::Send(outgoing_tls_length) => { if let Some(packet) = create_tcp_ack_with_data( &config.src_mac, &recv_eth_src, &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, connection.next_seq(), connection.next_ack(), &outgoing_tls[..outgoing_tls_length], ) { packet_tx.send(packet).expect("Failed to send TLS packet"); connection .set_next_seq(connection.next_seq() + outgoing_tls_length as u32); } } TlsState::DoNothing => { // debug!("When don't we have to do anything? 
I guess when we're waiting on more packets?"); } TlsState::Finish => { progress_tx .send(Http1Message::Result(( Http1Request::new_https( recv_src_ip, connection.payload(), recv_src_port, Some(connection.sni.clone()), ), Http1Response::from_bytes(connection.response.clone()) .expect("failed to serialize response"), ))) .expect("Failed to send result"); let packet = unwrap_or_continue!(create_tcp_rst( &config.src_mac, &recv_eth_src, &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, connection.next_seq(), )); packet_tx .send(packet.to_vec()) .expect("Failed to send RST packet"); connection_table.remove(syn_cookie); } TlsState::Error => { debug!("Error processing TLS connection"); connection_table.remove(syn_cookie); } } packet_count += 1; } Err(e) => { panic!("Failed to receive packet: {}", e) } } // Send progress update every second if last_progress.elapsed() > progress_interval { if progress_tx .send(Http1Message::RxStats { packet_count, full_read_buffer_count, }) .is_ok() { packet_count = 0 } last_progress = std::time::Instant::now(); // TODO: Remove any connection that has been open for too long // for (key, connection) in connection_table { // if connection.connection_timed_out(last_progress) { // connection_table.remove(*key); // } // } } } } ``` ### `scanners/mrpf_http1_scanner/src/scanner.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/scanner.rs` ```rust use super::progress::{Http1Message, Http1ProgressHandler}; use super::receive::receive_handler; use super::transmit::transmit_handler; use super::{Template, TemplateIterator, TemplateValue, TemplateVariable}; use crate::{Http1Request, Http1Response}; use mrpf_engine::ProgressHandler; use mrpf_engine::ScannerConfig; use mrpf_engine::network::arp::get_router_arp_address; use mrpf_engine::network::datalink::create_datalink_channel; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc; use std::thread; /// The HTTP 1/x Scanner with TLS 1.2. /// /// It is responsible for managing the network /// interfaces, sending and receiving packets, and handling status updates. /// /// The core of the engine is based on moving the problem of quickly scanning a large volume of /// network hosts from an IO bound to a CPU bound problem. This is done by separating the transmit /// and receive of TCP/IP network stack to two separate threads. /// /// Correlating send and receive packets is being done by SYN cookies instead of relying on sockets /// per connection. pub struct Http1Scanner { /// Scanner Configuration. /// /// This holds configuration data like the source MAC and IP address, the router IP address, /// and the network interface to use. pub config: ScannerConfig, } impl Http1Scanner { /// Initialize engine /// /// # Arguments /// /// * `config` - The scanner configuration. pub fn new(config: ScannerConfig) -> Http1Scanner { Http1Scanner { config } } /// Run a HTTP 1/x scan with the default progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `targets` - The list of targets to scan. /// /// # Returns /// /// The results of the scan. 
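/// # Example
///
/// Hypothetical usage sketch (hence `ignore`): the `ScannerConfig`, the IP and
/// port ranges and the SNI list are assumed to be built elsewhere with this
/// crate's types.
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let scanner = Http1Scanner::new(config); // config: ScannerConfig
/// let template =
///     b"GET / HTTP/1.1\r\nHost: ${SNI}\r\nConnection: close\r\n\r\n".to_vec();
/// let results = scanner.run(
///     ip_addresses,                        // Ipv4AddressRanges
///     ports,                               // PortRanges
///     vec!["www.mortimer.nl".to_string()], // SNIs
///     template,
///     HashMap::new(),                      // extra template variables
/// );
/// for (request, response) in results {
///     println!("{}:{} -> {}", request.ip(), request.port(), response.status_code());
/// }
/// ```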
pub fn run( &self, ip_addresses: Ipv4AddressRanges, ports: PortRanges, snis: Vec, template: Template, variables: HashMap>, ) -> Vec<(Http1Request, Http1Response)> { let progress_handler = Http1ProgressHandler; self.run_with_handle( ip_addresses, ports, snis, template, variables, progress_handler, ) } /// Run a TCP Syn Scan with a specific progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `targets` - The list of targets to scan. /// * `progress_handler` - The progress handler function to handle status updates. /// /// # Returns /// /// A JoinHandle containing the results of the scan. pub fn run_with_handle< P: ProgressHandler>, >( &self, ip_addresses: Ipv4AddressRanges, ports: PortRanges, snis: Vec, template: Template, variables: HashMap>, mut progress_handler: P, ) -> Vec<(Http1Request, Http1Response)> { let (mut tx, mut rx) = create_datalink_channel(&self.config.interface, 4096 * 4, 4096 * 4); // Retrieve the router mac address let router_mac = get_router_arp_address( self.config.src_mac, self.config.src_ip, self.config.router_ip, &mut *tx, &mut *rx, ) .expect("Router MAC address should be found"); // Create an atomic flag to signal the thread to stop let stop_flag = Arc::new(AtomicBool::new(false)); let stop_flag_clone = Arc::clone(&stop_flag); // Create a channel and separate thread for handling progress updates let (progress_tx, progress_rx) = mpsc::channel(); let progress_handle = thread::Builder::new() .name("progress_thread".to_string()) .spawn(move || progress_handler.handler(progress_rx)) .expect("Progress thread should have been spawned"); // Create a channel for sending packets from the receive thread to the transmit thread let (packet_tx, packet_rx) = mpsc::channel(); // Create a channel for sending new connections from the transmit thread to the receive thread let (new_connections_tx, new_connections_rx) = mpsc::channel(); // Create a thread for receiving packets let config = self.config.clone(); let progress_tx_clone = progress_tx.clone(); let template_iterator = TemplateIterator::new(template, variables); let receive_handle = thread::Builder::new() .name("receive_thread".to_string()) .spawn(move || { receive_handler( rx, packet_tx, progress_tx_clone, new_connections_rx, config, stop_flag_clone, ); }) .expect("Receive thread should have been spawned"); // Create a thread for sending packets let config = self.config.clone(); let targets = vec![(template_iterator, ip_addresses, ports, snis)]; let transmit_handle = thread::Builder::new() .name("transmit_thread".to_string()) .spawn(move || { transmit_handler( tx, packet_rx, progress_tx, new_connections_tx, config, router_mac, targets, ); }) .expect("transmit thread should have been spawned"); // When the transmit handle is finished, we should stop the receive thread // and return the results from the progress_handle. 
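// Shutdown order: the transmit thread returns once every target has been sent
// and its connection-timeout drain loop has elapsed; only then do we raise
// `stop_flag`. Note the receive thread re-checks that flag only after
// `rx.next()` returns, so unless the datalink channel was created with a read
// timeout the final join below may block until one more packet arrives.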
transmit_handle.join().unwrap(); stop_flag.store(true, Ordering::Relaxed); receive_handle.join().unwrap(); progress_handle.join().unwrap() } } ``` ### `scanners/mrpf_http1_scanner/src/template.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/template.rs` ```rust use aho_corasick::AhoCorasick; use std::collections::HashMap; use std::net::Ipv4Addr; pub type Template = Vec; pub type TemplateVariables = HashMap>; pub type TemplateVariable = Vec; pub type TemplateValue = Vec; #[derive(Debug, Clone)] /// Iterator for generating bytes with template variables /// /// This iterator takes in a body template and a list of variables and will generate a body /// for each combination of variables. /// /// NOTE: The order in which templates are generated will be different between runs. We /// could consider using a BTree instead to maintain order. pub struct TemplateIterator { /// The template with template variables that will be transformed template: Template, /// The list of variables that will be used to transform the template variables: TemplateVariables, /// Keep track of the variables we've used so far variable_idx: Vec, /// Is set to true when we've generated all possible combinations done: bool, } impl TemplateIterator { /// Create a new TemplateIterator pub fn new(template: Template, variables: TemplateVariables) -> Self { // Filter out variables without values let variables: TemplateVariables = variables .into_iter() .filter(|(_, values)| !values.is_empty()) .collect(); // variable_idx is a vector of the same length as amount of variables, with each // index keeping track of the current value for that variable. let variable_idx = vec![0; variables.len()]; for variable in variables.keys() { if !variable.starts_with(b"${") || !variable.ends_with(b"}") { panic!( "Template variable must start with '${{' and end with '}}', found {:?} instead", variable ); } } TemplateIterator { template, variables, variable_idx, done: false, } } // Function to retrieve the total amount of combinations that will be generated // NOTE: This is actually not accurate as we could have nested variables // pub fn len(&self) -> usize { // self.variables.values().map(|values| values.len()).product() // } /// Check if the iterator has no variables to iterate over pub fn is_empty(&self) -> bool { self.variables.is_empty() } } impl TemplateIterator { /// Add variables of a IP/Port/SNI to the existing variables /// /// NOTE: This will reset the iterator to the first combination /// /// # Arguments /// /// * `ip` - The IP address of the endpoint, will be stored as ${IP} /// * `port` - The port of the endpoint, will be stored as ${PORT} /// * `sni` - The SNI of the endpoint, will be stored as ${SNI} if not None pub fn add_host_variables(&mut self, ip: Ipv4Addr, port: u16, sni: Option) { self.variables .insert(b"${IP}".to_vec(), vec![ip.to_string().into()]); self.variables.insert( b"${PORT}".to_vec(), vec![format!("{}", port).as_bytes().to_vec()], ); if let Some(sni) = sni { self.variables .insert(b"${SNI}".to_vec(), vec![sni.as_bytes().to_vec()]); } // Ensure to update the variable_idx to match the new variables self.variable_idx = vec![0; self.variables.len()]; self.done = false; } } impl Iterator for TemplateIterator { type Item = Vec; fn next(&mut self) -> Option { if self.done { return None; } if self.variables.is_empty() { self.done = true; return Some(self.template.clone()); } // Recursively loop through the template to replace any variables. 
// Since each variable value could contain another variable, we need to loop // until no more replacements are made. let mut transformed_body = self.template.clone(); loop { let mut replacements = HashMap::new(); for (variable_name, values) in self.variables.iter() { // Fetch the current value for this variable let value = &values[self.variable_idx[self .variables .keys() .position(|k| k == variable_name) .unwrap()]]; // Collect the variable names and values into a HashMap replacements.insert(variable_name.clone(), value.clone()); } // Call the replace_variables function let new_transformed_body = replace_variables(&transformed_body, &replacements); // Check if any replacements were made if new_transformed_body == transformed_body { break; } transformed_body = new_transformed_body; } // Update the indices to generate the next combination for i in (0..self.variable_idx.len()).rev() { self.variable_idx[i] += 1; if self.variable_idx[i] < self.variables.values().nth(i).unwrap().len() { break; } self.variable_idx[i] = 0; if i == 0 { self.done = true; } } Some(transformed_body) } } /// Replace variables in a body with the given replacements fn replace_variables(body: &[u8], replacements: &HashMap, Vec>) -> Vec { // Create an Aho-Corasick automaton with all the variable names let patterns: Vec<&[u8]> = replacements.keys().map(|k| k.as_slice()).collect(); let ac = AhoCorasick::new(&patterns).expect("Failed to create Aho-Corasick automaton"); // Create a mutable buffer to store the transformed body let mut transformed_body = Vec::with_capacity(body.len()); // Iterate through the matches and perform the replacements let mut last_match_end = 0; for mat in ac.find_iter(body) { // Append the text before the match transformed_body.extend_from_slice(&body[last_match_end..mat.start()]); // Append the replacement value let pattern = patterns[mat.pattern()]; if let Some(replacement) = replacements.get(pattern) { transformed_body.extend_from_slice(replacement); } // Update the last match end position last_match_end = mat.end(); } // Append the remaining text after the last match transformed_body.extend_from_slice(&body[last_match_end..]); transformed_body } #[cfg(test)] mod tests { use super::*; #[test] fn test_single_variable() { let template = b"Hello, ${name}!".to_vec(); let mut variables = HashMap::new(); variables.insert(b"${name}".to_vec(), vec![b"World".to_vec()]); let mut iter = TemplateIterator::new(template, variables); assert_eq!(iter.next(), Some(b"Hello, World!".to_vec())); assert_eq!(iter.next(), None); } #[test] fn test_multiple_variables() { let template = b"Hello, ${name} from ${place}!".to_vec(); let mut variables = HashMap::new(); variables.insert( b"${name}".to_vec(), vec![b"Alice".to_vec(), b"Bob".to_vec()], ); variables.insert( b"${place}".to_vec(), vec![b"Wonderland".to_vec(), b"Builderland".to_vec()], ); let iter = TemplateIterator::new(template, variables); let mut expected_values = vec![ b"Hello, Alice from Wonderland!".to_vec(), b"Hello, Alice from Builderland!".to_vec(), b"Hello, Bob from Wonderland!".to_vec(), b"Hello, Bob from Builderland!".to_vec(), ]; for value in iter { assert!(expected_values.contains(&value)); expected_values.retain(|x| x != &value); } assert!(expected_values.is_empty()); } #[test] fn test_multiple_variables_with_uneven_value_count() { let template = b"Hello, ${name} from ${place}!".to_vec(); let mut variables = HashMap::new(); variables.insert(b"${name}".to_vec(), vec![b"Alice".to_vec()]); variables.insert( b"${place}".to_vec(), 
vec![b"Wonderland".to_vec(), b"Builderland".to_vec()], ); let iter = TemplateIterator::new(template, variables); let mut expected_values = vec![ b"Hello, Alice from Wonderland!".to_vec(), b"Hello, Alice from Builderland!".to_vec(), ]; for value in iter { assert!(expected_values.contains(&value)); expected_values.retain(|x| x != &value); } assert!(expected_values.is_empty()); } #[test] fn test_add_host_variables() { let template = b"Connecting to ${IP}:${PORT} with SNI ${SNI}".to_vec(); let variables = HashMap::new(); let mut iter = TemplateIterator::new(template, variables); iter.add_host_variables( Ipv4Addr::new(127, 0, 0, 1), 8080, Some("www.mortimer.nl".to_string()), ); assert_eq!( iter.next(), Some(b"Connecting to 127.0.0.1:8080 with SNI www.mortimer.nl".to_vec()) ); assert_eq!(iter.next(), None); } #[test] fn test_no_variables() { let template = b"Hello, World!".to_vec(); let variables = HashMap::new(); let mut iter = TemplateIterator::new(template, variables); assert_eq!(iter.next(), Some(b"Hello, World!".to_vec())); assert_eq!(iter.next(), None); } #[test] fn test_empty_template() { let template = b"".to_vec(); let mut variables = HashMap::new(); variables.insert(b"${var}".to_vec(), vec![b"value".to_vec()]); let mut iter = TemplateIterator::new(template, variables); assert_eq!(iter.next(), Some(b"".to_vec())); assert_eq!(iter.next(), None); } #[test] fn test_nested_variables() { let template = b"Hello, ${greeting}!".to_vec(); let mut variables = HashMap::new(); variables.insert( b"${greeting}".to_vec(), vec![b"${name} from ${place}".to_vec()], ); variables.insert(b"${name}".to_vec(), vec![b"Alice".to_vec()]); variables.insert(b"${place}".to_vec(), vec![b"Wonderland".to_vec()]); let mut iter = TemplateIterator::new(template, variables); assert_eq!(iter.next(), Some(b"Hello, Alice from Wonderland!".to_vec())); assert_eq!(iter.next(), None); } } ``` ### `scanners/mrpf_http1_scanner/src/tls.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/tls.rs` ```rust use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::{ClientConfig, Error, RootCertStore}; use std::sync::{Arc, Once}; /// A Rustls server certificate verifier that accepts any certificate. 
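///
/// Certificate validation is skipped on purpose: the scanner only wants to
/// record whatever certificate and content a host presents. Do not reuse this
/// verifier anywhere that needs an authenticated TLS peer.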
/// /// TAKEN FROM https://github.com/hatoo/oha/blob/6c4a0a515363f8039bbf94e009d4511feb9f8a15/src/client.rs#L713 /// also another example in the test here: https://github.com/jbg/tokio-postgres-rustls/blob/master/src/lib.rs #[derive(Debug)] pub struct AcceptAnyServerCert; impl ServerCertVerifier for AcceptAnyServerCert { fn verify_server_cert( &self, _end_entity: &rustls_pki_types::CertificateDer<'_>, _intermediates: &[rustls_pki_types::CertificateDer<'_>], _server_name: &rustls_pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: rustls_pki_types::UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls_pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls_pki_types::CertificateDer<'_>, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { rustls::crypto::CryptoProvider::get_default() .unwrap() .signature_verification_algorithms .supported_schemes() } } static mut TLS_CLIENT_CONFIG: Option> = None; static INIT: Once = Once::new(); /// Get a Rustls TLS client configuration that accepts any server certificate. /// /// Calling this function multiple times will return the same configuration. #[expect(static_mut_refs)] pub(crate) fn get_tls_client_config() -> Arc { unsafe { INIT.call_once(|| { let mut config = ClientConfig::builder() .with_root_certificates(RootCertStore::empty()) .with_no_client_auth(); config .dangerous() .set_certificate_verifier(Arc::new(AcceptAnyServerCert)); TLS_CLIENT_CONFIG = Some(Arc::new(config)); }); TLS_CLIENT_CONFIG.clone().unwrap() } } ``` ### `scanners/mrpf_http1_scanner/src/transmit.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_http1_scanner/src/transmit.rs` ```rust use super::TemplateIterator; use super::connections::TlsConnection; use super::progress::Http1Message; use mrpf_engine::ScannerConfig; use mrpf_engine::network::datalink::send_packet; use mrpf_engine::network::tcp::create_syn_tcp_packet_v4; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use pnet::datalink::DataLinkSender; use pnet::util::MacAddr; use rand::Rng; use std::sync::mpsc::{Receiver, Sender}; use std::thread::sleep; use std::time::{Duration, Instant}; use tracing::info; // Interval of status updates and reset of our packet counts const PROGRESS_INTERVAL: Duration = Duration::from_secs(1); /// The main function responsible for sending packets. /// /// To quote Masscan: /// /// ```quote /// THIS IS WHERE ALL THE EXCITEMENT HAPPENS!!!! /// 90% of CPU cycles are in this function. /// ``` pub(crate) fn transmit_handler( mut tx: Box, packet_rx: Receiver>, progress_tx: Sender, new_connections_tx: Sender, mut config: ScannerConfig, // TODO move router mac to the builder? Then we only have to do it once // instead of each scan and it makes it easier to pass on values. router_mac: MacAddr, targets: Vec<(TemplateIterator, Ipv4AddressRanges, PortRanges, Vec)>, ) { // Progress counters let mut syn_packet_per_second: u32 = 0; let mut packets_per_second: u32 = 0; let mut full_write_buffer_count = 0; let mut last_progress = Instant::now() - PROGRESS_INTERVAL; // Start rate limiter config.rate_limiter.start(); // Windows uses 49152 to 655352, linux uses 32768 to 609991, so best to generate a // source port between 49152 and 609991 to keep them guessing. 
// However, if we're not able to spoof our source IP address, macOS or Linux will // send RST packets for TCP connections established by us. These packets will // interfere with our scanning. To avoid this you can create a firewall rule // to drop outbound RST packets for TCP connections initiated from the OS source port range. // This is not optimal to run on your own system but acceptable for dedicated scanning systems. let mut rng = rand::rng(); // Parse EXCLUDE_OUTBOUND_TCP_PORTS env var ("low-high") into RangeInclusive, fallback to default let src_port_range = std::env::var("EXCLUDE_OUTBOUND_TCP_PORTS") .ok() .and_then(|s| { let (low, high) = s.split_once('-')?; let low: u16 = low.parse().ok()?; let high: u16 = high.parse().ok()?; if low <= high { // All a bit hacky, but we want to exclude the range from source ports Some(10000..=low - 1) } else { None } }) .unwrap_or(32768..=65535); info!(src_port_range = ?src_port_range, "Source port range for scanning"); for (template_iterator, ipv4_addresses, ports, snis) in targets { for template in template_iterator { for sni in snis.clone() { for dst_port in ports.clone() { let src_port: u16 = rng.random_range(src_port_range.clone()); for dst_ip in ipv4_addresses.clone() { // Transmit pending packets from the packet receiver. // // These packets are generated by the receive thread. They should get precedence over // normal SYN packets, otherwise remote servers might start to retransmit or drop connections. for packet in packet_rx.try_iter() { config.rate_limiter.wait_for_token(); send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } // Establish new TCP connection let tcp_syn_packet = create_syn_tcp_packet_v4( &config, &router_mac, src_port, &dst_ip.into(), dst_port, ); if let Some((syn_cookie, packet)) = tcp_syn_packet { config.rate_limiter.wait_for_token(); // Create a new connection and send it to the receive thread // so it can track connection status let connection = TlsConnection::new( sni.clone(), template.clone(), syn_cookie + 1, syn_cookie + 1, ); new_connections_tx .send(connection) .expect("Send new connection"); // Send the initial SYN packet to the remote host send_packet(&mut tx, &packet, &mut full_write_buffer_count); syn_packet_per_second += 1; packets_per_second += 1; } // Send status updates to our process handler and reset counters. if last_progress.elapsed() > PROGRESS_INTERVAL { send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); last_progress = Instant::now(); } } } } } } // Update queued up progress statistics send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); // Process any queued packets from the receive thread for the duration // of the connection timeout. 
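// This drain exists so that connections opened near the end of the target list
// can still complete their TLS/HTTP exchange: the receive thread keeps handing
// us ACK, RST and TLS-record packets to forward. Problem.md in this crate
// sketches replacing this fixed wait with an explicit "both threads are done"
// signal between transmit and receive.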
let start_time = Instant::now(); while start_time.elapsed() < config.connection_timeout { for packet in packet_rx.try_iter() { send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } sleep(Duration::from_millis(10)); } // Send final statistic update send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); // Scan is finished, signal our process handler progress_tx .send(Http1Message::Finished) .expect("Send finished message"); } /// Send transmit progress update #[inline] fn send_progress_update( progress_tx: Sender, syn_packet_per_second: &mut u32, packets_per_second: &mut u32, full_write_buffer_count: &mut u8, ) { progress_tx .send(Http1Message::TxStats { syn_packet_count: *syn_packet_per_second, packet_count: *packets_per_second, full_write_buffer_count: *full_write_buffer_count, }) .unwrap(); *syn_packet_per_second = 0; *packets_per_second = 0; *full_write_buffer_count = 0; } ``` ### `scanners/mrpf_tcp_syn_scanner/src/lib.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tcp_syn_scanner/src/lib.rs` ```rust pub mod progress; mod receive; pub mod scanner; mod transmit; pub use progress::TcpSynMessage; pub use scanner::TcpSynScanner; ``` ### `scanners/mrpf_tcp_syn_scanner/src/progress.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tcp_syn_scanner/src/progress.rs` ```rust use mrpf_engine::ProgressHandler; use std::collections::{HashMap, HashSet}; use std::net::Ipv4Addr; use std::sync::mpsc::Receiver; /// Progress Messages sent during a TCP SYN scan. pub enum TcpSynMessage { // A found TCP Open Port and the corresponding IP address Result((Ipv4Addr, u16)), /// Receive statistics RxStats { packet_count: u32, full_read_buffer_count: u8, }, /// Transmit statistics TxStats { syn_packet_count: u32, packet_count: u32, full_write_buffer_count: u8, }, /// The scan is finished Finished, } /// Headless ProgressHandler for the TCP SYN Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. /// No other progress messages will be parsed since we're running in a headless manner. #[derive(Debug)] pub struct TcpSynProgressHandler; impl ProgressHandler for TcpSynProgressHandler { type Message = TcpSynMessage; type Results = HashMap>; fn handler(&mut self, receiver: Receiver) -> Self::Results { let mut results = HashMap::>::new(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { TcpSynMessage::Result((ip, port)) => { results.entry(ip).or_default().insert(port); } TcpSynMessage::Finished => break, _ => {} } } results } } ``` ### `scanners/mrpf_tcp_syn_scanner/src/receive.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tcp_syn_scanner/src/receive.rs` ```rust use super::progress::TcpSynMessage; use mrpf_engine::network::arp::handle_arp_requests; use mrpf_engine::network::ip::generate_flow_cookie_v4; use mrpf_engine::network::tcp::create_tcp_rst; use mrpf_engine::{ScannerConfig, unwrap_or_continue}; use pnet::datalink::DataLinkReceiver; use pnet::packet::Packet; use pnet::packet::ethernet::{EtherTypes, EthernetPacket}; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::tcp::TcpFlags::{ACK, SYN}; use pnet::packet::tcp::TcpPacket; use std::sync::mpsc::Sender; use std::time::Duration; use tracing::debug; /// The main receive handler for the network engine. 
/// /// The function runs in a loop and receives packets from the network interface. /// /// It uses the SYN cookie algorithm to determine if the received packet belongs /// to the scanner. pub(crate) fn receive_handler( mut rx: Box, packet_tx: Sender>, progress_tx: Sender, config: ScannerConfig, ) { let mut packet_count = 0; let full_read_buffer_count = 0; let progress_interval: Duration = Duration::from_secs(1); // We explicitly set the last progress to half the interval to try and // be out of sync with the transmit thread. This should spread the load // of the progress updates a bit more. let mut last_progress = std::time::Instant::now() - (progress_interval / 2); loop { match rx.next() { Ok(packet) => { // Determine if the received packet belongs to our scanner // TODO: I want to see if I can do ARP spoofing so I can filter packets to the spoofed // ARP address early here to determine it comes from the scanner. This will // allow us to skip the TCP cookie generation for a proper comparison. let ethernet_packet = unwrap_or_continue!(EthernetPacket::new(packet)); if handle_arp_requests(ðernet_packet, &config, &packet_tx) { // ARP request handled continue; } if ethernet_packet.get_destination() != config.src_mac || ethernet_packet.get_ethertype() != EtherTypes::Ipv4 { // Packet was not destined for us continue; } let ipv4_packet = unwrap_or_continue!(Ipv4Packet::new(ethernet_packet.payload())); if ipv4_packet.get_destination() != config.src_ip { // This is especially useful if we have a dedicated (e.g. spoofed) IP for the scanner continue; } if ipv4_packet.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { // debug!("Received non-TCP packet"); continue; } let tcp_packet = unwrap_or_continue!(TcpPacket::new(ipv4_packet.payload())); // Calculate our original syn cookie and see if we have a match let syn_cookie_ack = generate_flow_cookie_v4( &ipv4_packet.get_destination(), tcp_packet.get_destination(), &ipv4_packet.get_source(), tcp_packet.get_source(), config.entropy, ); if tcp_packet.get_acknowledgement() == (syn_cookie_ack + 1) && tcp_packet.get_flags() & (SYN | ACK) == (SYN | ACK) { // Explicitly send a RST packet to close the connection // // This avoids the destination from retrying the SYN/ACK // and decreases the incoming traffic. // // TODO: Not sure if the RST is really received or processed by the destination? // It seems we're still getting new SYN/ACK packets after sending the RST. // Perhaps I'm messing up the ack/seq numbers? 
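// For a RST sent in reply to a SYN/ACK, the sequence number should equal the ACK
// value of that SYN/ACK, i.e. our ISN (the flow cookie) + 1, which is the
// `syn_cookie_ack + 1` passed below.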
let tcp_rst = unwrap_or_continue!(create_tcp_rst( &config.src_mac, ðernet_packet.get_source(), &config.src_ip, tcp_packet.get_destination(), &ipv4_packet.get_source(), tcp_packet.get_source(), syn_cookie_ack + 1, )); packet_tx .send(tcp_rst.to_vec()) .expect("Failed to send TCP RST packet"); // Open port found debug!( ipv4_packet.get_source = ?ipv4_packet.get_source(), tcp_packet.get_source = ?tcp_packet.get_source(), "Open port found" ); progress_tx .send(TcpSynMessage::Result(( ipv4_packet.get_source(), tcp_packet.get_source(), ))) .unwrap(); } packet_count += 1; } Err(e) => { panic!("Failed to receive packet: {}", e) } } // Send progress update every second if last_progress.elapsed() > progress_interval { if progress_tx .send(TcpSynMessage::RxStats { packet_count, full_read_buffer_count, }) .is_ok() { packet_count = 0 } last_progress = std::time::Instant::now(); } } } ``` ### `scanners/mrpf_tcp_syn_scanner/src/scanner.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tcp_syn_scanner/src/scanner.rs` ```rust use super::progress::{TcpSynMessage, TcpSynProgressHandler}; use super::receive::receive_handler; use super::transmit::transmit_handler; use mrpf_engine::ProgressHandler; use mrpf_engine::ScannerConfig; use mrpf_engine::network::arp::get_router_arp_address; use mrpf_engine::network::datalink::create_datalink_channel; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use std::collections::{HashMap, HashSet}; use std::net::Ipv4Addr; use std::sync::mpsc; use std::thread; /// The TCP Syn Scanner. /// /// It is responsible for managing the network /// interfaces, sending and receiving packets, and handling status updates. /// /// The core of the engine is based on moving the problem of quickly scanning a large volume of /// network hosts from an IO bound to a CPU bound problem. This is done by separating the transmit /// and receive of TCP/IP network stack to two separate threads. /// /// Correlating send and receive packets is being done by SYN cookies instead of relying on sockets /// per connection. /// /// TODO: /// Spoofing MAC still doesn't work. The router doesn't seem to respond to my arp request for router address /// we're only getting it implicitly after a while due to the thing broadcasting itself. /// /// *TODO* This should become only public inside the crate. Our consumers should /// instead be using the various scanner implementations. /// /// # Example /// /// ```ignore /// use mrpf::network::datalink::get_interface; /// use mrpf::{ScannerConfig, TcpSynScanner}; /// use std::net::Ipv4Addr; /// /// let interface = get_interface(None).unwrap(); /// let src_mac = interface.mac.unwrap(); /// /// let config = ScannerConfig { /// interface, /// src_ip: Ipv4Addr::new(192, 168, 1, 182), /// src_mac, /// router_ip: Ipv4Addr::new(192, 168, 1, 1), /// }; /// /// let scanner = TcpSynScanner::new(config); /// let ipv4_addresses = vec![Ipv4Addr::new(192, 168, 1, 100)]; /// let ports = vec![80, 443]; /// /// let results = scanner.run(ipv4_addresses, ports); /// for (ip, open_ports) in results { /// debug!("IP: {}, Open Ports: {:?}", ip, open_ports); /// } /// ``` pub struct TcpSynScanner { /// Scanner Configuration. /// /// This holds configuration data like the source MAC and IP address, the router IP address, /// and the network interface to use. 
pub config: ScannerConfig, } impl TcpSynScanner { /// Initialize engine pub fn new(config: ScannerConfig) -> TcpSynScanner { // Mac spoofing still not working, perhaps it's something to do with // unicast/multicat + local/manufactor assigned code? this is encoded // in the first and second significant bits of the first octet // let src_mac: MacAddr = MacAddr::new(0x00, 0xad, 0xde, 0xad, 0xde, 0xad); TcpSynScanner { config } } /// Run a TCP Syn Scan with the default progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `ipv4_addresses` - The list of IPv4 addresses to scan. /// * `ports` - The list of ports to scan. /// /// # Returns /// /// A HashMap containing the results of the scan. pub fn run( &self, ipv4_addresses: Ipv4AddressRanges, ports: PortRanges, ) -> HashMap> { let progress_handler = TcpSynProgressHandler; self.run_with_handle(ipv4_addresses, ports, progress_handler) } /// Run a TCP Syn Scan with a specific progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `ipv4_addresses` - The list of IPv4 addresses to scan. /// * `ports` - The list of ports to scan. /// * `progress_handler` - The progress handler function to handle status updates. /// /// # Returns /// /// A JoinHandle containing the results of the scan. pub fn run_with_handle< P: ProgressHandler>>, >( &self, ipv4_addresses: Ipv4AddressRanges, ports: PortRanges, mut progress_handler: P, ) -> HashMap> { let (mut tx, mut rx) = create_datalink_channel(&self.config.interface, 4096 * 4, 4096 * 4); // Retrieve the router mac address let router_mac = get_router_arp_address( self.config.src_mac, self.config.src_ip, self.config.router_ip, &mut *tx, &mut *rx, ) .expect("Router MAC address should be found"); // Create a channel and separate thread for handling progress updates let (progress_tx, progress_rx) = mpsc::channel(); let progress_handle = thread::Builder::new() .name("progress_thread".to_string()) .spawn(move || progress_handler.handler(progress_rx)) .expect("Progress thread should have been spawned"); // Create a channel for sending packets from the receive thread to the transmit thread let (packet_tx, packet_rx) = mpsc::channel(); // Create a thread for receiving packets let config = self.config.clone(); let progress_tx_clone = progress_tx.clone(); let receive_handle = thread::Builder::new() .name("receive_thread".to_string()) .spawn(move || { receive_handler(rx, packet_tx, progress_tx_clone, config); }) .expect("Receive thread should have been spawned"); // Create a thread for sending packets let config = self.config.clone(); let transmit_handle = thread::Builder::new() .name("transmit_thread".to_string()) .spawn(move || { transmit_handler( tx, packet_rx, progress_tx, config, router_mac, ipv4_addresses, ports, ); }) .expect("transmit thread should have been spawned"); // When the transmit handle is finished, we can drop the receive handle // and return the results from the progress_handle. transmit_handle.join().unwrap(); // TODO: This actually doesn't stop the receive thread. We will have to stop the thread by // passing a AtomicBool Arc to the transmit and receive thread so we can signal it's done. // There might be other control flags we'd like to set such as pausing the loop, changing // the rate limiting settings, etc. 
Maybe we can have a shared struct that the progress, // transmit and receive threads use to communicate with each other. drop(receive_handle); progress_handle.join().unwrap() } } ``` ### `scanners/mrpf_tcp_syn_scanner/src/transmit.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tcp_syn_scanner/src/transmit.rs` ```rust use super::progress::TcpSynMessage; use mrpf_engine::ScannerConfig; use mrpf_engine::network::datalink::send_packet; use mrpf_engine::network::tcp::create_syn_tcp_packet_v4; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use pnet::datalink::DataLinkSender; use pnet::util::MacAddr; use rand::Rng; use std::sync::mpsc::{Receiver, Sender}; use std::thread::sleep; use std::time::{Duration, Instant}; use tracing::info; // Interval of status updates and reset of our packet counts const PROGRESS_INTERVAL: Duration = Duration::from_secs(1); /// The main function responsible for sending packets. /// /// To quote Masscan: /// /// ```quote /// THIS IS WHERE ALL THE EXCITEMENT HAPPENS!!!! /// 90% of CPU cycles are in this function. /// ``` pub(crate) fn transmit_handler( mut tx: Box, packet_rx: Receiver>, progress_tx: Sender, mut config: ScannerConfig, // TODO move router mac to the builder? Then we only have to do it once // instead of each scan and it makes it easier to pass on values. router_mac: MacAddr, ipv4_addresses: Ipv4AddressRanges, ports: PortRanges, ) { // Progress counters let mut syn_packet_per_second: u32 = 0; let mut packets_per_second: u32 = 0; let mut full_write_buffer_count = 0; let mut last_progress = Instant::now() - PROGRESS_INTERVAL; // Start rate limiter config.rate_limiter.start(); // Windows uses 49152 to 655352, linux uses 32768 to 609991, so best to generate a // source port between 49152 and 609991 to keep them guessing. // However, if we're not able to spoof our source IP address, macOS or Linux will // send RST packets for TCP connections established by us. These packets will // interfere with our scanning. To avoid this you can create a firewall rule // to drop outbound RST packets for TCP connections initiated from the OS source port range. // This is not optimal to run on your own system but acceptable for dedicated scanning systems. let mut rng = rand::rng(); // Parse EXCLUDE_OUTBOUND_TCP_PORTS env var ("low-high") into RangeInclusive, fallback to default let src_port_range = std::env::var("EXCLUDE_OUTBOUND_TCP_PORTS") .ok() .and_then(|s| { let (low, high) = s.split_once('-')?; let low: u16 = low.parse().ok()?; let high: u16 = high.parse().ok()?; if low <= high { // All a bit hacky, but we want to exclude the range from source ports Some(10000..=low - 1) } else { None } }) .unwrap_or(32768..=65535); info!(src_port_range = ?src_port_range, "Source port range for scanning"); for dst_port in ports { let src_port: u16 = rng.random_range(src_port_range.clone()); for dst_ip in ipv4_addresses.clone() { // Transmit pending packets from the packet receiver. // // These packets are generated by the receive thread. They should get precedence over // normal SYN packets, otherwise remote servers might start to retransmit or drop connections. 
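// `try_iter` drains whatever is currently queued without blocking, so an empty
// queue never delays SYN generation.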
for packet in packet_rx.try_iter() { config.rate_limiter.wait_for_token(); send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } // Establish new TCP connection let tcp_syn_packet = create_syn_tcp_packet_v4(&config, &router_mac, src_port, &dst_ip.into(), dst_port); if let Some((_, packet)) = tcp_syn_packet { config.rate_limiter.wait_for_token(); send_packet(&mut tx, &packet, &mut full_write_buffer_count); syn_packet_per_second += 1; packets_per_second += 1; } // Send status updates to our process handler and reset counters. if last_progress.elapsed() > PROGRESS_INTERVAL { send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); last_progress = Instant::now(); } } } // Update queued up progress statistics send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); // Process any queued packets from the receive thread for the duration // of the connection timeout. let start_time = Instant::now(); while start_time.elapsed() < config.connection_timeout { for packet in packet_rx.try_iter() { send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } sleep(Duration::from_millis(10)); } // Send final statistic update send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); // Scan is finished, signal our process handler progress_tx .send(TcpSynMessage::Finished) .expect("Send finished message"); } /// Send transmit progress update #[inline] fn send_progress_update( progress_tx: Sender, syn_packet_per_second: &mut u32, packets_per_second: &mut u32, full_write_buffer_count: &mut u8, ) { progress_tx .send(TcpSynMessage::TxStats { syn_packet_count: *syn_packet_per_second, packet_count: *packets_per_second, full_write_buffer_count: *full_write_buffer_count, }) .unwrap(); *syn_packet_per_second = 0; *packets_per_second = 0; *full_write_buffer_count = 0; } ``` ### `scanners/mrpf_tls_scanner/src/how_to_manage_snis.md` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/how_to_manage_snis.md` ```markdown # my tls scanner needs to be a vector<(SNI, Ports, Ipv4Addresses)> This will provide the most flexibility. We can either scan a large range of IP addresses for a single SNI to find internal hosts, wildcards or other domains not directly. , or scrape the SAN certs from a whole bunch of known Host + their IP address (e.g. feed known domains into Tls Scanner to find more domains) This is the right input indeed but the implementation is more tricky. We can't just iterate over this in the transmit queue as it's actually the receive queue that has to construct the SNI. Thats what you get when separating transmit and receive in different queues :). If we can find a very robust pattern for this, it would open up a really nice network stack that we can use for other protocols. For instance, I will have the same issue with HTTP scanner. We somehow have to maintain state inside of the receive thread to track what hosts we've tackled already. If we're not careful with cleaning up, this can become quite memory intensive. It would be so great if we could smuggle the data somehow into the initial TCP syn packet so we know what we're dealing with? Perhaps we should use the syn cookie to track this? So instead of dst_ip and port, we'll feed in dst_ip, port and sni? 
This means we would have to build a shadow table that contains all ip/port/sni hashes, though, which is probably VERY big when we're scanning large ranges?

## Example calculation

We have 10,000,000 IP addresses with 100 ports and 1 hostname. Each IP address + port + hostname pair is represented by a u32 value, which is 4 bytes.

To calculate the total memory usage:

Number of pairs = 10,000,000 IP addresses * 100 ports * 1 hostname = 1,000,000,000 pairs

Total memory usage = 1,000,000,000 pairs * 4 bytes per pair = 4,000,000,000 bytes = 4,000,000 KB = 4,000 MB = 4 GB

So, the total memory usage would be 4 GB.

## What are my use cases?

- Scan large ranges of ip/ports for a single sni (No table is needed as we can just iterate over the ranges)
- Scan a single ip/port for multiple sni's (A table is needed)
- Scan large ranges of ip/ports for multiple sni's (A table is needed)

## Update to the input

Initially I thought the input should be a Vec<(SNI, Ipv4Addresses, Ports)>. However, that would be sub-optimal given the following case. As you can see, you would need to clone the ip addresses and ports for each SNI. It's also more tedious to write out.

```rust
let targets = vec![
    (
        "example.com".to_string(),
        ipv4_addresses.clone(),
        ports.clone(),
    ),
    ("mortimer.nl".to_string(), ipv4_addresses, ports),
];
```

The better structure would be a Vec<(Vec<SNI>, Ipv4Addresses, Ports)>:

```rust
let targets = vec![
    (
        vec!["example.com", "mortimer.nl"].into(),
        ipv4_addresses,
        ports,
    )
];
```

## TODO: A big optimization for memory would be to split the TLS scans into a single scan per SNI

Then we don't have to maintain the full SNI table in memory at all!! It would move more of the logic to the run_with_handle function, as we will need to start multiple scans there. It would probably also require some changes to the progress monitor. Doing this should not have any real impact on the scanning itself, as we're still just spitting out packets as fast as we can push them into the transmit queue. See the sketch at the end of this note.

We won't be able to optimize further by keeping the TCP connection open and handling multiple SNIs in the same TCP connection though... That's probably too difficult anyway, and I suspect some servers won't like it either (although it would be a nice avenue to investigate whether you can bypass some host-based routing on load balancers... something for another day).

However, what if we have a large list of unique SNIs that belong to a single ip/port tuple? For instance, if we want to fetch all known active domains and get their SAN certs. It would mean starting a new scan for EACH SNI. So that means spinning up new threads + waiting for the connect timeout after each scan.

## Optimization that probably will make it too difficult?

What would be nice is if we could keep the TCP connection open when we're fuzzing for multiple SNI's in the same connection. However, this would perhaps be very difficult to fit into the same receive function that fuzzes over ranges of ports and ip addresses?
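## Sketch of the per-SNI split

A rough, hypothetical sketch of the per-SNI split from the TODO above (not implemented), assuming the current `TlsScanner::run(Vec<(Vec<SNI>, Ipv4AddressRanges, PortRanges)>)` signature and that the range types are `Clone`:

```rust
use std::collections::{HashMap, HashSet};
use std::net::Ipv4Addr;

use mrpf_engine::{Ipv4AddressRanges, PortRanges};
use mrpf_tls_scanner::TlsScanner;

/// Run one scan per SNI and merge the results, so the receive thread only ever
/// needs the table-free single-SNI target lookup.
fn scan_per_sni(
    scanner: &TlsScanner,
    snis: Vec<String>,
    ips: Ipv4AddressRanges,
    ports: PortRanges,
) -> HashMap<String, HashSet<(Ipv4Addr, u16)>> {
    let mut merged: HashMap<String, HashSet<(Ipv4Addr, u16)>> = HashMap::new();
    for sni in snis {
        // Each run spawns fresh transmit/receive threads and waits out the
        // connection timeout; that is the per-scan overhead mentioned above.
        let results = scanner.run(vec![(vec![sni], ips.clone(), ports.clone())]);
        for (name, endpoints) in results {
            merged.entry(name).or_default().extend(endpoints);
        }
    }
    merged
}
```

The trade-off is exactly the one described above: memory stays flat, but every SNI pays the thread start-up and connection-timeout cost.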
``` ### `scanners/mrpf_tls_scanner/src/lib.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/lib.rs` ```rust pub mod progress; mod receive; pub mod scanner; pub(crate) mod targets; mod transmit; pub use progress::TlsMessage; pub use scanner::TlsScanner; ``` ### `scanners/mrpf_tls_scanner/src/progress.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/progress.rs` ```rust use mrpf_engine::ProgressHandler; use std::collections::{HashMap, HashSet}; use std::net::Ipv4Addr; use std::sync::mpsc::Receiver; use tracing::{debug, error}; /// Progress Messages sent during a TLS scan. pub enum TlsMessage { // A list of names and the Ip and port they were found on Result((HashSet, Ipv4Addr, u16)), /// Receive statistics RxStats { packet_count: u32, full_read_buffer_count: u8, }, /// Transmit statistics TxStats { syn_packet_count: u32, packet_count: u32, full_write_buffer_count: u8, }, /// The scan is finished Finished, } /// Headless ProgressHandler for the TLS SNI Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. /// No other progress messages will be parsed since we're running in a headless manner. #[derive(Debug)] pub struct TlsProgressHandler; impl ProgressHandler for TlsProgressHandler { type Message = TlsMessage; type Results = HashMap>; fn handler(&mut self, receiver: Receiver) -> Self::Results { let mut results = HashMap::>::new(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { error!("TLS ProgressHandler receiver error: {}", e); break; } }; match message { TlsMessage::Result((names, ip, port)) => { debug!(names = ?names, ip = %ip, port = %port, "Received TLS scan result"); for name in names { results.entry(name).or_default().insert((ip, port)); } } TlsMessage::Finished => break, _ => {} } } results } } ``` ### `scanners/mrpf_tls_scanner/src/receive.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/receive.rs` ```rust use super::progress::TlsMessage; use super::targets::TlsTargets; use mrpf_cert_parser::{clean_fqdn, extract_cn_and_san}; use mrpf_engine::network::arp::handle_arp_requests; use mrpf_engine::network::ip::generate_flow_cookie_v4; use mrpf_engine::network::tcp::{create_tcp_ack, create_tcp_rst}; use mrpf_engine::network::tls::extract_tls12_leaf_certificate; use mrpf_engine::network::tls::{TlsExtractError, create_tls_client_hello}; use mrpf_engine::{Connection, TcpConnection, TcpConnectionTable}; use mrpf_engine::{ScannerConfig, unwrap_or_continue}; use pnet::datalink::DataLinkReceiver; use pnet::packet::Packet; use pnet::packet::ethernet::{EtherTypes, EthernetPacket}; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::tcp::TcpFlags::{ACK, FIN, PSH, RST, SYN}; use pnet::packet::tcp::TcpPacket; use std::collections::HashSet; use std::sync::mpsc::Sender; use std::time::Duration; use tracing::debug; /// The main receive handler for the network engine. /// /// The function runs in a loop and receives packets from the network interface. /// /// It uses the SYN cookie algorithm to determine if the received packet belongs /// to the scanner. pub(crate) fn receive_handler( mut rx: Box, packet_tx: Sender>, progress_tx: Sender, config: ScannerConfig, mut tls_targets: T, ) { let mut packet_count = 0; let full_read_buffer_count = 0; let progress_interval: Duration = Duration::from_secs(1); // We explicitly set the last progress to half the interval to try and // be out of sync with the transmit thread. 
This should spread the load // of the progress updates a bit more. let mut last_progress = std::time::Instant::now() - (progress_interval / 2); // Connection Table let mut connection_table = TcpConnectionTable::::new(); loop { match rx.next() { Ok(packet) => { // Determine if the received packet belongs to our scanner // TODO: I want to see if I can do ARP spoofing so I can filter packets to the spoofed // ARP address early here to determine it comes from the scanner. This will // allow us to skip the TCP cookie generation for a proper comparison. let ethernet_packet = unwrap_or_continue!(EthernetPacket::new(packet)); let recv_eth_src = ethernet_packet.get_source(); if handle_arp_requests(ðernet_packet, &config, &packet_tx) { // ARP request handled continue; } if ethernet_packet.get_destination() != config.src_mac || ethernet_packet.get_ethertype() != EtherTypes::Ipv4 { // Packet was not destined for us continue; } let ipv4_packet = unwrap_or_continue!(Ipv4Packet::new(ethernet_packet.payload())); let recv_src_ip = ipv4_packet.get_source(); debug!( "Received IP packet from {} to {}. It should map to scanner IP {}", recv_src_ip, ipv4_packet.get_destination(), config.src_ip ); if ipv4_packet.get_destination() != config.src_ip { // This is especially useful if we have a dedicated (e.g. spoofed) IP for the scanner continue; } if ipv4_packet.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { // debug!("Received non-TCP packet"); continue; } let tcp_packet = unwrap_or_continue!(TcpPacket::new(ipv4_packet.payload())); // Client/server port naming clarity let recv_src_port = tcp_packet.get_source(); // remote server port let recv_dst_port = tcp_packet.get_destination(); // our local source port let tcp_flags = tcp_packet.get_flags(); let recv_seq = tcp_packet.get_sequence(); let recv_ack = tcp_packet.get_acknowledgement(); debug!( "Parsed TCP Packet from {}:{} to {}:{} with flags: {:?}, seq: {}, ack: {}", recv_src_ip, recv_src_port, config.src_ip, recv_dst_port, tcp_flags, recv_seq, recv_ack ); // Generate flow cookie for correlation let syn_cookie = generate_flow_cookie_v4( &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, config.entropy, ); if recv_ack == (syn_cookie + 1) && tcp_flags == (SYN | ACK) { debug!( "Received valid SYN ACK from {}:{}", recv_src_ip, recv_src_port ); // Send ACK packet to complete the handshake let tcp_ack = unwrap_or_continue!(create_tcp_ack( &config.src_mac, &recv_eth_src, &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, recv_ack, recv_seq + 1, )); debug!( "Sending TCP ACK to {}:{} from {}:{}", recv_src_ip, recv_src_port, config.src_ip, recv_dst_port ); packet_tx .send(tcp_ack.to_vec()) .expect("Failed to send TCP ACK packet"); // Send TLS Client Hello debug!( "Creating TLS Client Hello for {}:{}", recv_src_ip, recv_src_port ); let sni = unwrap_or_continue!(tls_targets.get_sni(recv_src_ip, recv_src_port)); debug!("SNI for {}:{} is {}", recv_src_ip, recv_src_port, sni); let (tls_client_hello, payload_len) = unwrap_or_continue!(create_tls_client_hello( &config.src_mac, &recv_eth_src, &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, recv_ack, recv_seq + 1, sni, )); debug!( "Sending TLS Client Hello to {}:{}, from: {}:{}", recv_src_ip, recv_src_port, config.src_ip, recv_dst_port ); packet_tx .send(tls_client_hello) .expect("Failed to send TLS Client Hello packet"); // Store the connection in the connection table // TODO: These ack and stuff are wrong. recv_seq + 1 is actually the expected seq. 
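// As currently written, the stored values are recv_seq + 1 (the server's ISN + 1,
// the next byte we expect from it) and recv_ack + payload_len (our ISN + 1 plus the
// Client Hello payload length we just queued).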
debug!( "Storing connection for {}:{} in connection table", recv_src_ip, recv_src_port ); let connection = TcpConnection::new(recv_seq + 1, recv_ack + payload_len); connection_table.insert(syn_cookie, connection); } else if tcp_flags == (ACK) || tcp_flags == (ACK | PSH) { debug!( "Received TCP packet with ACK/ACK PSH from {}:{}", recv_src_ip, recv_src_port ); if !tcp_packet.payload().is_empty() { debug!( "TCP packet has payload len {} from {}:{}", tcp_packet.payload().len(), recv_src_ip, recv_src_port ); // If there is no connection in the connection table, we can ignore // as we've likely already closed the connection let connection = unwrap_or_continue!(connection_table.get(syn_cookie)); debug!( "Found connection in table for {}:{}", recv_src_ip, recv_src_port ); connection.add_segment(tcp_packet.payload(), recv_seq); // Parse the TLS certificate debug!( "Connection data len: {}, recv_seq: {}", connection .get_reassembled_buffer() .as_ref() .map_or(0, |buf| buf.len()), recv_seq, // connection.next_seq() ); let connection_data = unwrap_or_continue!(connection.get_reassembled_buffer()); debug!( "Connection data len: {}, recv_seq: {}", connection_data.len(), recv_seq, // connection.next_seq() ); match extract_tls12_leaf_certificate(connection_data.as_slice()) { Ok(certificate) => { // debug!( // "\n\n\nEXTRACTED CERT CHAIN: {:?}\n\n\n", // certificate_chain // ); debug!( recv_src_ip = %recv_src_ip, recv_src_port = %recv_src_port, "Successfully extracted TLS certificate chain", ); debug!( "Extracted certificate len: {} bytes from {}:{}", certificate.len(), recv_src_ip, recv_src_port ); match extract_cn_and_san(certificate) { Ok((cn, san)) => { debug!( cn = ?cn, san = ?san, recv_src_ip = %recv_src_ip, recv_src_port = %recv_src_port, "Extracted CN and SAN from certificate", ); debug!( "Extracted CN: {:?}, SANs: {:?} from {}:{}", cn, san, recv_src_ip, recv_src_port ); let mut names: HashSet = HashSet::new(); if let Some(cn) = cn { names.insert(clean_fqdn(cn)); } for san_name in san { names.insert(clean_fqdn(san_name)); } if !names.is_empty() { debug!( "Sending names {:?} from {}:{} to progress channel", names, recv_src_ip, recv_src_port ); progress_tx .send(TlsMessage::Result(( names, recv_src_ip, recv_src_port, ))) .unwrap(); } } Err(e) => { debug!( error = %e, recv_src_ip = %recv_src_ip, recv_src_port = %recv_src_port, "Failed to extract CN and SAN from certificate", ); debug!( "Failed to extract CN and SAN from certificate from {}:{}: {}", recv_src_ip, recv_src_port, e ); } } } Err(TlsExtractError::Incomplete) => continue, // Wait for more packets Err(_) => {} // Parsing failed, close the connection }; // Remove the connection from the connection table connection_table.remove(syn_cookie); // Close the connection let tcp_rst = unwrap_or_continue!(create_tcp_rst( &config.src_mac, &recv_eth_src, &config.src_ip, recv_dst_port, &recv_src_ip, recv_src_port, // Sequence number is original syn + client hello + client hello data for the RST // THIS NEEDS TO BE FIXED AS WE'RE NOW UPDATING THE SNI IN THE CLIENT HELLO. // OUR CONNECTION TABLE SHOULD KEEP TRACK OF OUR SEQUENCE NUMBER. THIS WILL ALSO // ALLOW US TO PARSE MORE INCOMING PACKETS AFTER THE FIRST PACKET REASSEMBLY // AS THE SYN COOKIE WON'T MATCH ANYMORE. 
// connection.next_seq() syn_cookie + 165, )); packet_tx .send(tcp_rst.to_vec()) .expect("Failed to send TCP RST packet"); } } else { connection_table.remove(syn_cookie); if tcp_flags == FIN { debug!("Received tcp FIN {:?}", tcp_flags); } else if tcp_flags == RST { debug!("Received tcp RST {:?}", tcp_flags); } else if tcp_flags == FIN | ACK { debug!("Received tcp FIN ACK {:?}", tcp_flags); } else if tcp_flags == FIN | ACK | PSH { debug!("Received tcp FIN ACK PSH {:?}", tcp_flags); } else { debug!("Received TCP packet with unmatched flags: {:?}", tcp_flags); // todo!( // "Handle other TCP packets. probably means the connection is being closed, so we can clean the connection table if not done yet." // ); } } packet_count += 1; } Err(e) => { panic!("Failed to receive packet: {}", e) } } // Send progress update every second if last_progress.elapsed() > progress_interval { if progress_tx .send(TlsMessage::RxStats { packet_count, full_read_buffer_count, }) .is_ok() { packet_count = 0 } last_progress = std::time::Instant::now(); } } } ``` ### `scanners/mrpf_tls_scanner/src/scanner.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/scanner.rs` ```rust use super::progress::{TlsMessage, TlsProgressHandler}; use super::receive::receive_handler; use super::targets::{TlsTargetTable, TlsTargetTableSingleSNI}; use super::transmit::transmit_handler; use mrpf_engine::ProgressHandler; use mrpf_engine::ScannerConfig; use mrpf_engine::network::arp::get_router_arp_address; use mrpf_engine::network::datalink::create_datalink_channel; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use std::collections::{HashMap, HashSet}; use std::net::Ipv4Addr; use std::sync::mpsc; use std::thread; /// The TLS SNI Syn Scanner. /// /// It is responsible for managing the network /// interfaces, sending and receiving packets, and handling status updates. /// /// The core of the engine is based on moving the problem of quickly scanning a large volume of /// network hosts from an IO bound to a CPU bound problem. This is done by separating the transmit /// and receive of TCP/IP network stack to two separate threads. /// /// Correlating send and receive packets is being done by SYN cookies instead of relying on sockets /// per connection. /// /// TODO: /// Spoofing MAC still doesn't work. The router doesn't seem to respond to my arp request for router address /// we're only getting it implicitly after a while due to the thing broadcasting itself. /// /// *TODO* This should become only public inside the crate. Our consumers should /// instead be using the various scanner implementations. /// /// # Example /// /// ```ignore /// use mrpf::network::datalink::get_interface; /// use mrpf::{ScannerConfig, TlsScanner}; /// use std::net::Ipv4Addr; /// /// let interface = get_interface(None).unwrap(); /// let src_mac = interface.mac.unwrap(); /// /// let config = ScannerConfig { /// interface, /// src_ip: Ipv4Addr::new(192, 168, 1, 182), /// src_mac, /// router_ip: Ipv4Addr::new(192, 168, 1, 1), /// }; /// /// let scanner = TlsScanner::new(config); /// let ipv4_addresses = vec![Ipv4Addr::new(192, 168, 1, 100)]; /// let ports = vec![80, 443]; /// /// let results = scanner.run(ipv4_addresses, ports); /// for (ip, open_ports) in results { /// debug!("IP: {}, Open Ports: {:?}", ip, open_ports); /// } /// ``` pub struct TlsScanner { /// Scanner Configuration. /// /// This holds configuration data like the source MAC and IP address, the router IP address, /// and the network interface to use. 
pub config: ScannerConfig, } impl TlsScanner { /// Initialize engine pub fn new(config: ScannerConfig) -> TlsScanner { // Mac spoofing still not working, perhaps it's something to do with // unicast/multicat + local/manufactor assigned code? this is encoded // in the first and second significant bits of the first octet // let src_mac: MacAddr = MacAddr::new(0x00, 0xad, 0xde, 0xad, 0xde, 0xad); TlsScanner { config } } /// Run a TCP Syn Scan with the default progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `ipv4_addresses` - The list of IPv4 addresses to scan. /// * `ports` - The list of ports to scan. /// /// # Returns /// /// A map of names and the ip and port they were found on pub fn run( &self, targets: Vec<(Vec, Ipv4AddressRanges, PortRanges)>, ) -> HashMap> { let progress_handler = TlsProgressHandler; self.run_with_handle(targets, progress_handler) } /// Run a TCP Syn Scan with a specific progress handler protocol. /// /// This will open up the datalink channels, spawn the sending and receiving threads, /// and handle status updates. /// /// # Arguments /// /// * `ipv4_addresses` - The list of IPv4 addresses to scan. /// * `ports` - The list of ports to scan. /// * `progress_handler` - The progress handler function to handle status updates. /// /// # Returns /// /// A JoinHandle containing the results of the scan. pub fn run_with_handle< P: ProgressHandler>>, >( &self, targets: Vec<(Vec, Ipv4AddressRanges, PortRanges)>, mut progress_handler: P, ) -> HashMap> { let (mut tx, mut rx) = create_datalink_channel(&self.config.interface, 4096 * 4, 4096 * 4); // Retrieve the router mac address let router_mac = get_router_arp_address( self.config.src_mac, self.config.src_ip, self.config.router_ip, &mut *tx, &mut *rx, ) .expect("Router MAC address should be found"); // Create a channel and separate thread for handling progress updates let (progress_tx, progress_rx) = mpsc::channel(); let progress_handle = thread::Builder::new() .name("progress_thread".to_string()) .spawn(move || progress_handler.handler(progress_rx)) .expect("Progress thread should have been spawned"); // Create a channel for sending packets from the receive thread to the transmit thread let (packet_tx, packet_rx) = mpsc::channel(); // Create a thread for receiving packets let config = self.config.clone(); let progress_tx_clone = progress_tx.clone(); let receive_handle = if targets[0].0.len() > 1 { // We have multiple SNIs to scan so use the less efficient // TlsTargetTable implementation let tls_targets = TlsTargetTable::new(&targets); thread::Builder::new() .name("receive_thread".to_string()) .spawn(move || { receive_handler(rx, packet_tx, progress_tx_clone, config, tls_targets); }) .expect("Receive thread should have been spawned") } else { // We only have a single SNI to scan so we can use the more efficient // TlsTargetTableSingleSNI implementation let tls_targets = TlsTargetTableSingleSNI::new(targets[0].0[0].clone()); thread::Builder::new() .name("receive_thread".to_string()) .spawn(move || { receive_handler(rx, packet_tx, progress_tx_clone, config, tls_targets); }) .expect("Receive thread should have been spawned") }; // Create a thread for sending packets let config = self.config.clone(); let transmit_handle = thread::Builder::new() .name("transmit_thread".to_string()) .spawn(move || { transmit_handler(tx, packet_rx, progress_tx, config, router_mac, targets); }) .expect("transmit thread should 
have been spawned"); // When the transmit handle is finished, we can drop the receive handle // and return the results from the progress_handle. transmit_handle.join().unwrap(); drop(receive_handle); progress_handle.join().unwrap() } } ``` ### `scanners/mrpf_tls_scanner/src/targets.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/targets.rs` ```rust use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use std::collections::HashMap; use std::net::Ipv4Addr; /// Trait that returns the next SNI to scan. /// /// Here are a few use cases that we want to be able to support /// - Scan a single ip/port tuples for multiple SNI's (Single host VHOST scanning) /// - Scan large ranges of ip/port tuples for multiple SNI's (Find a list of unresponsive fqdns on given cloud/customer ip ranges) /// - Scan large ranges of ip/port tuples for a single SNI (Find a single SNI on given cloud/customer ip ranges, eg. find a origin server) /// - Scrape SNI from a vector of single ip/port/sni tuples (Get all known active domains of a target and scrape their SNIs) /// /// The last one does not need to store the ip/port/SNI combination in a hashmap, as we only need to store a single SNI. /// This is a more memory efficient way to scan large ranges of ip/port for a single SNI and is implemented by TlsTargetTableSingleSNI. /// The other scenarios are implemented by TlsTargetTable. pub(crate) trait TlsTargets { fn get_sni(&mut self, ip_address: Ipv4Addr, port: u16) -> Option; } /// Returns the next SNI to scan. /// /// It maintains a hashmap of ip address/port combination with the list of SNI names to test. SNI's /// are popped from the list as they are scanned. /// /// This implementation is useful for the following scenarios: /// - Scan a single ip/port tuples for multiple SNI's (Single host VHOST scanning) /// - Scan large ranges of ip/port tuples for multiple SNI's (Find a list of unresponsive fqdns on given cloud/customer ip ranges) /// - Scan large ranges of ip/port tuples for a single SNI (Find a single SNI on given cloud/customer ip ranges, eg. find a origin server) /// /// If you are only scanning for a single SNI, use TlsTargetTableSingleSNI instead. /// /// NOTE that depending on the number of targets, this table can grow quite large. For example: /// /// We have 10,000,000 IP addresses with 100 ports and 1 hostname. Each IP address + port + hostname pair is represented by a u32 value, which is 4 bytes. /// /// Number of pairs = 10,000,000 IP addresses * 100 ports * 1 hostname /// = 1,000,000,000 pairs /// /// Total memory usage = 1,000,000,000 pairs * 4 bytes per pair /// = 4,000,000,000 bytes /// = 4,000,000 KB /// = 4,000 MB /// = 4 GB pub(crate) struct TlsTargetTable { // A hashmap of ip address/port combination with the list of SNI names to test targets: HashMap<(Ipv4Addr, u16), Vec>, } impl TlsTargetTable { /// Create a new TlsTargetTable. pub(crate) fn new(targets: &Vec<(Vec, Ipv4AddressRanges, PortRanges)>) -> Self { let mut target_table = HashMap::new(); for (hostnames, ipv4_addresses, ports) in targets { for dst_port in ports.clone() { for dst_ip in ipv4_addresses.clone() { target_table .entry((dst_ip.into(), dst_port)) .and_modify(|e: &mut Vec| e.extend(hostnames.clone())) .or_insert(hostnames.clone()); } } } Self { targets: target_table, } } } impl TlsTargets for TlsTargetTable { /// Retrieve the next SNI to scan. 
/// /// # Arguments /// /// * `ip_address` - The IP address to scan /// * `port` - The port to scan /// /// # Returns /// /// The next SNI to scan or None if there are no more SNI's to scan. fn get_sni(&mut self, ip_address: Ipv4Addr, port: u16) -> Option { let target = self.targets.get_mut(&(ip_address, port))?; let sni = target.pop()?; Some(sni) } } /// Returns the next SNI to scan. /// /// This implementation is useful for the following scenario: /// - Scan ranges of ip/port tuples for a single SNI /// /// This implementation is more memory efficient than TlsTargetTable, as it only needs to store a single SNI. pub(crate) struct TlsTargetTableSingleSNI { // A hashmap of ip address/port combination with the list of SNI names to test sni: String, } impl TlsTargetTableSingleSNI { /// Create a new TlsTargetTableSingleSNI. pub(crate) fn new(sni: String) -> Self { Self { sni } } } impl TlsTargets for TlsTargetTableSingleSNI { /// Retrieve the next SNI to scan. /// /// # Arguments /// /// * `ip_address` - The IP address to scan /// * `port` - The port to scan /// /// # Returns /// /// The next SNI to scan or None if there are no more SNI's to scan. fn get_sni(&mut self, _ip_address: Ipv4Addr, _port: u16) -> Option { Some(self.sni.clone()) } } ``` ### `scanners/mrpf_tls_scanner/src/transmit.rs` **Absolute path:** `/root/mrpf/scanners/mrpf_tls_scanner/src/transmit.rs` ```rust use super::progress::TlsMessage; use mrpf_engine::ScannerConfig; use mrpf_engine::network::datalink::send_packet; use mrpf_engine::network::tcp::create_syn_tcp_packet_v4; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use pnet::datalink::DataLinkSender; use pnet::util::MacAddr; use rand::Rng; use std::sync::mpsc::{Receiver, Sender}; use std::time::{Duration, Instant}; use tracing::{debug, info}; // Interval of status updates and reset of our packet counts const PROGRESS_INTERVAL: Duration = Duration::from_secs(1); /// The main function responsible for sending packets. /// /// To quote Masscan: /// /// ```quote /// THIS IS WHERE ALL THE EXCITEMENT HAPPENS!!!! /// 90% of CPU cycles are in this function. /// ``` pub(crate) fn transmit_handler( mut tx: Box, packet_rx: Receiver>, progress_tx: Sender, mut config: ScannerConfig, // TODO move router mac to the builder? Then we only have to do it once // instead of each scan and it makes it easier to pass on values. router_mac: MacAddr, targets: Vec<(Vec, Ipv4AddressRanges, PortRanges)>, ) { // Progress counters let mut syn_packet_per_second: u32 = 0; let mut packets_per_second: u32 = 0; let mut full_write_buffer_count = 0; let mut last_progress = Instant::now() - PROGRESS_INTERVAL; // Start rate limiter config.rate_limiter.start(); // Windows uses 49152 to 655352, linux uses 32768 to 609991, so best to generate a // source port between 32768 and 655352 to keep them guessing. // However, if we're not able to spoof our source IP address, macOS or Linux will // send RST packets for TCP connections established by us. These packets will // interfere with our scanning. To avoid this you can create a firewall rule // to drop outbound RST packets for TCP connections initiated from the OS source port range. // This is not optimal to run on your own system but acceptable for dedicated scanning systems. 
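// EXCLUDE_OUTBOUND_TCP_PORTS takes a "low-high" range, typically the OS ephemeral
// port range (32768-60999 by default on Linux, 49152-65535 on Windows). When it is
// set, source ports are drawn from 10000 up to just below `low`, keeping them clear
// of the OS range; when it is absent or malformed, 32768-65535 is used.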
let mut rng = rand::rng(); // Parse EXCLUDE_OUTBOUND_TCP_PORTS env var ("low-high") into RangeInclusive, fallback to default let src_port_range = std::env::var("EXCLUDE_OUTBOUND_TCP_PORTS") .ok() .and_then(|s| { let (low, high) = s.split_once('-')?; let low: u16 = low.parse().ok()?; let high: u16 = high.parse().ok()?; if low <= high { // All a bit hacky, but we want to exclude the range from source ports Some(10000..=low - 1) } else { None } }) .unwrap_or(32768..=65535); info!(src_port_range = ?src_port_range, "Source port range for scanning"); for (hostnames, ipv4_addresses, ports) in targets { for _ in hostnames { for dst_port in ports.clone() { let src_port: u16 = rng.random_range(src_port_range.clone()); for dst_ip in ipv4_addresses.clone() { // Transmit pending packets from the packet receiver. // // These packets are generated by the receive thread. They should get precedence over // normal SYN packets, otherwise remote servers might start to retransmit or drop connections. for packet in packet_rx.try_iter() { config.rate_limiter.wait_for_token(); debug!("transmit_handler: sending queued packet {}", packet.len()); send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } // Establish new TCP connection let tcp_syn_packet = create_syn_tcp_packet_v4( &config, &router_mac, src_port, &dst_ip.into(), dst_port, ); if let Some((_, packet)) = tcp_syn_packet { config.rate_limiter.wait_for_token(); send_packet(&mut tx, &packet, &mut full_write_buffer_count); syn_packet_per_second += 1; packets_per_second += 1; } // Send status updates to our process handler and reset counters. if last_progress.elapsed() > PROGRESS_INTERVAL { send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); last_progress = Instant::now(); } } } } } // Update queued up progress statistics send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); // Process any queued packets from the receive thread for the duration // of the connection timeout * 2 (to allow for retransmits). let start_time = Instant::now(); while start_time.elapsed() < (config.connection_timeout * 2) { for packet in packet_rx.try_iter() { config.rate_limiter.wait_for_token(); debug!("transmit_handler: sending queued packet {}", packet.len()); send_packet(&mut tx, &packet, &mut full_write_buffer_count); packets_per_second += 1; } } // Send final statistic update send_progress_update( progress_tx.clone(), &mut syn_packet_per_second, &mut packets_per_second, &mut full_write_buffer_count, ); // Scan is finished, signal our process handler progress_tx .send(TlsMessage::Finished) .expect("Send finished message"); } /// Send transmit progress update #[inline] fn send_progress_update( progress_tx: Sender, syn_packet_per_second: &mut u32, packets_per_second: &mut u32, full_write_buffer_count: &mut u8, ) { progress_tx .send(TlsMessage::TxStats { syn_packet_count: *syn_packet_per_second, packet_count: *packets_per_second, full_write_buffer_count: *full_write_buffer_count, }) .unwrap(); *syn_packet_per_second = 0; *packets_per_second = 0; *full_write_buffer_count = 0; } ``` ### `scanners/mrpf_whois/README.md` **Absolute path:** `/root/mrpf/scanners/mrpf_whois/README.md` ```markdown # MRPF Whois scanner A ultra performant whois crawler built on top of the mrpf_engine. It support both legacy whois protocol and RDAP protocol. It can also do bulk queries to RDAP servers. 
# Copilot suggestions Awesome idea. **WHOIS data + RDAP** is pure recon fuel, and writing your own **Rust** client gives you full control (and no per‑query fees). Below is a practical blueprint: - what to build (WHOIS + RDAP), - how to deal with referrals, thick vs thin registries, and RIRs, - an architecture that respects ToU/rate‑limits, - and **ready‑to‑run Rust examples** (async, robust, and testable). I’ve cited the relevant specs and registries so you can design against the actual standards. --- ## 0) The landscape, in 60 seconds - **WHOIS (RFC 3912)** is a *TCP/43 text protocol*, human‑oriented, unstructured, and inconsistent across registries. It’s not going away, but it’s legacy and hard to parse reliably.[1](https://www.rfc-editor.org/rfc/rfc3912) - **RDAP (RFC 9082/9083)** is WHOIS’ successor over HTTPS with **JSON** responses and uniform URL patterns. It’s machine‑readable, has authentication/authorization hooks, and is maintained by the IETF. ARIN and others **recommend RDAP** for automation.[2](https://datatracker.ietf.org/doc/html/rfc9082)[3](https://www.rfc-editor.org/rfc/rfc9083.html)[4](https://www.arin.net/resources/registry/whois/rws/faq/) - **Bootstrapping RDAP**: you discover which RDAP server to use via IANA **bootstrap JSON** (e.g., `dns.json`, `ipv4.json`, `ipv6.json`, `asn.json`).[5](https://www.iana.org/assignments/rdap-dns/rdap-dns.xhtml) > TL;DR: Implement **both**: WHOIS (for completeness, some ccTLDs still rely on it) **and** RDAP (for structured, automatable data). --- ## 1) What bug bounty value do you actually get? - **Domains**: registrar, creation/expiry dates, registrant/contacts (often redacted), **nameservers** (pivot to providers), DNSSEC flags, status codes. RDAP gives normalized **status** and **events** arrays.[3](https://www.rfc-editor.org/rfc/rfc9083.html) - **IPs/ASNs**: allocation owner, abuse contacts, net ranges from **RIR WHOIS/RDAP** (ARIN/APNIC/LACNIC/RIPE/AFRINIC). ARIN exposes both port‑43 and RDAP and encourages RDAP for automation.[6](https://www.arin.net/resources/registry/whois/)[4](https://www.arin.net/resources/registry/whois/rws/faq/) - **Referrals**: Thin gTLDs like **.com/.net** point you to the registrar; you must follow **referral** to get detailed data. (Verisign runs the .com/.net registry and publishes WHOIS usage terms; high‑volume scraping can be throttled—respect ToU.)[7](https://www.verisign.com/en_US/domain-names/whois-lookup/index.xhtml) --- ## 2) Architecture for a Rust client ### 2.1 High‑level modules - `core/` - `resolver.rs`: figures out if input is **domain, IP, or ASN**, and chooses WHOIS vs RDAP. - `psl.rs`: uses Mozilla **Public Suffix List** to find eTLD+1 (e.g., `example.co.uk`)—critical for correct TLD extraction. Crates: `publicsuffix` (dynamic) or `psl` (precompiled).[8](https://docs.rs/publicsuffix/latest/publicsuffix/)[9](https://github.com/addr-rs/psl)[10](https://docs.rs/public-suffix/) - `whois/` - `client.rs`: async TCP/43 client; follows **referrals** (e.g., from IANA or Verisign to registrar WHOIS). - `parse.rs`: heuristic field extraction (server‑specific templates/regex). - `rdap/` - `bootstrap.rs`: downloads & caches IANA `dns.json`/`ipv4.json`/`ipv6.json`/`asn.json` (24h cache).[5](https://www.iana.org/assignments/rdap-dns/rdap-dns.xhtml) - `client.rs`: HTTPS with `reqwest`, JSON to structs via **serde** (per RFC 9083).[3](https://www.rfc-editor.org/rfc/rfc9083.html) - `limits/` - rate‑limiter (e.g., token bucket) + polite backoff; per‑server concurrency caps. 
- `cli/` - flags like `--whois-only`, `--rdap-only`, `--max-referrals`, `--json`, `--timeout`. ### 2.2 Query flow (domain) 1) **Classify** string → domain. 2) **WHOIS path** (optional): - Query `whois.iana.org` for the **TLD** to find `whois:` referral server (works for many TLDs). Or hit registry (.com/.net) **Verisign** then follow to **registrar** WHOIS (thin registry).[1](https://www.rfc-editor.org/rfc/rfc3912)[7](https://www.verisign.com/en_US/domain-names/whois-lookup/index.xhtml) 3) **RDAP path** (preferred): - Load IANA **`dns.json`**; perform longest label‑wise match (e.g., `com` → Verisign RDAP base). Build URL `/domain/example.com` per **RFC 9082** and parse **JSON** per **RFC 9083**. Cache bootstrap files.[5](https://www.iana.org/assignments/rdap-dns/rdap-dns.xhtml)[2](https://datatracker.ietf.org/doc/html/rfc9082)[3](https://www.rfc-editor.org/rfc/rfc9083.html) ### 2.3 Query flow (IP / ASN) - Use IANA **`ipv4.json`/`ipv6.json`/`asn.json`** to choose the **RIR RDAP** (most‑specific match for IP ranges).[11](https://rdap.rcode3.com/bootstrapping/iana.html) - For WHOIS: ARIN/RIPE/APNIC/LACNIC/AFRINIC servers on port 43; ARIN also supports RESTful and RDAP.[6](https://www.arin.net/resources/registry/whois/)[12](https://www.arin.net/resources/registry/whois/rws/api/) --- ## 3) Rust: minimal WHOIS client (async + referrals) ```rust // Cargo.toml (excerpt) // tokio = { version = "1", features = ["full"] } // anyhow = "1" // tokio-util = { version = "0.7", features = ["codec"] } // futures = "0.3" use anyhow::{Context, Result}; use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream}; use std::time::Duration; async fn whois_query(server: &str, query: &str) -> Result { let addr = format!("{}:43", server); let mut stream = tokio::time::timeout(Duration::from_secs(8), TcpStream::connect(&addr)) .await .context("timeout connecting whois")??; // WHOIS requests end with CRLF per RFC 3912 let req = format!("{}\r\n", query); stream.write_all(req.as_bytes()).await?; stream.shutdown().await.ok(); let mut buf = Vec::new(); stream.read_to_end(&mut buf).await?; Ok(String::from_utf8_lossy(&buf).to_string()) } // naive referral extraction: look for common fields fn find_referral(next: &str) -> Option { for line in next.lines() { let l = line.to_lowercase(); if l.starts_with("whois server:") || l.starts_with("referralserver:") || l.starts_with("refer:") { return l.split_whitespace().last().map(|s| s.trim().trim_start_matches("whois://").to_string()); } } None } // example: get TLD whois server via whois.iana.org, then follow pub async fn whois_domain_follow(domain: &str) -> Result { // 1) ask IANA for the TLD let tld = domain.rsplit('.').next().unwrap_or(domain); let iana = whois_query("whois.iana.org", tld).await?; // RFC 3912 WHOIS; many TLDs include a "whois:" line pointing to the registry WHOIS // Now follow referral (if any) if let Some(registry) = find_referral(&iana) { let reg_result = whois_query(®istry, domain).await?; // Some thin registries (.com/.net) include a "Whois Server:" line giving the registrar WHOIS if let Some(registrar) = find_referral(®_result) { let final_result = whois_query(®istrar, domain).await?; return Ok(final_result); } return Ok(reg_result); } Ok(iana) } ``` **Notes** - The wire protocol is **TCP/43, CRLF‑terminated** requests; server closes the connection when done.[1](https://www.rfc-editor.org/rfc/rfc3912) - For **.com/.net**, registry WHOIS is Verisign; it often returns `Whois Server:` pointing to the **registrar** WHOIS—follow it if you want more detail 
(bearing in mind GDPR redaction and ToU).[7](https://www.verisign.com/en_US/domain-names/whois-lookup/index.xhtml) --- ## 4) Rust: RDAP bootstrap + query ```rust // Cargo.toml (excerpt) // reqwest = { version = "0.12", features = ["json", "gzip", "brotli", "deflate", "rustls-tls"] } // serde = { version = "1", features = ["derive"] } // serde_json = "1" // anyhow = "1" // tokio = { version = "1", features = ["full"] } use anyhow::{Context, Result}; use reqwest::Client; use serde::Deserialize; use std::{collections::HashMap, time::Duration}; // Minimal model of IANA dns.json bootstrap #[derive(Deserialize)] struct BootstrapFile { services: Vec<(Vec, Vec)>, } async fn load_bootstrap_dns(client: &Client) -> Result { let url = "https://data.iana.org/rdap/dns.json"; Ok(client.get(url).timeout(Duration::from_secs(10)).send().await? .error_for_status()?.json::().await?) } // Longest label-wise match for TLD fn pick_rdap_base(tld: &str, boot: &BootstrapFile) -> Option { for (entries, urls) in &boot.services { if entries.iter().any(|e| e.eq_ignore_ascii_case(tld)) { return urls.iter().find(|u| u.starts_with("https://")).cloned(); } } None } // RDAP domain response (trimmed to what's commonly needed) #[derive(Deserialize, Debug)] struct RdapDomain { objectClassName: String, ldhName: Option, status: Option>, events: Option>>, nameservers: Option>>, } pub async fn rdap_domain(domain: &str) -> Result { let client = Client::builder().user_agent("asnbank-recon/0.1").build()?; let boot = load_bootstrap_dns(&client).await?; let tld = domain.rsplit('.').next().unwrap_or(domain); let base = pick_rdap_base(tld, &boot).context("no RDAP base for TLD")?; let url = format!("{}/domain/{}", base.trim_end_matches('/'), domain); // RFC 9082 path patterns; RFC 9083 JSON schema let resp = client.get(url).timeout(Duration::from_secs(10)).send().await? .error_for_status()?.json::().await?; Ok(resp) } ``` - URL patterns (`/domain/{name}`, `/ip/{addr}`, `/autnum/{asn}`) are defined by **RFC 9082**; JSON fields by **RFC 9083**.[2](https://datatracker.ietf.org/doc/html/rfc9082)[3](https://www.rfc-editor.org/rfc/rfc9083.html) - Bootstrap data is the **IANA** registry (`dns.json`, etc.). Cache locally and refresh daily.[5](https://www.iana.org/assignments/rdap-dns/rdap-dns.xhtml) --- ## 5) Getting TLDs right (PSL) To extract the TLD/eTLD+1 reliably (vital for building WHOIS and RDAP queries), use the **Public Suffix List**: - `publicsuffix` (dynamic list you load at runtime) or `psl` (pre‑compiled for speed).[8](https://docs.rs/publicsuffix/latest/publicsuffix/)[9](https://github.com/addr-rs/psl) ```rust // Cargo.toml: publicsuffix = "2" // Load PSL once at start (from file or baked-in string) use publicsuffix::{List, Psl}; let psl_text = include_str!("../assets/public_suffix_list.dat"); let list: List = psl_text.parse().unwrap(); let etld1 = list.domain(b"payments.example.co.uk").unwrap(); // "example.co.uk" ``` --- ## 6) Server‑specific behaviors worth handling - **RIPE WHOIS** adds a lot of helpful flags (`-B`, `-i`, etc.) and has documentation for query semantics (more/less specifics, inverse queries). If you add CLI passthrough for power users, mirror those flags.[13](https://www.ripe.net/documents/2192/ripe-358.pdf) - **ARIN** offers port‑43 WHOIS, a RESTful Whois‑RWS, **and RDAP**; they strongly recommend RDAP for automation. 
Good target for IP/ASN queries.[4](https://www.arin.net/resources/registry/whois/rws/faq/)[12](https://www.arin.net/resources/registry/whois/rws/api/) - **.com/.net (Verisign)** are **thin**: registry WHOIS generally returns registrar referral. Respect **Verisign’s WHOIS ToU** (no high‑volume automated scraping; they can throttle). Prefer RDAP for automation where possible.[7](https://www.verisign.com/en_US/domain-names/whois-lookup/index.xhtml) - **IANA bootstrap**: for domains use **longest label match** in `dns.json`; for IPs pick the **most specific** prefix in `ipv4.json/ipv6.json`; for ASNs use range entries in `asn.json`.[11](https://rdap.rcode3.com/bootstrapping/iana.html) --- ## 7) Rate‑limits, caching, and ToU (non‑negotiable) - Many registries rate‑limit or forbid high‑volume WHOIS scrapes; **Verisign** explicitly restricts automated, high‑volume queries. Implement **exponential backoff**, **per‑server tokens**, and **disk cache** (e.g., per domain for 24h).[7](https://www.verisign.com/en_US/domain-names/whois-lookup/index.xhtml) - Favor RDAP (HTTPS JSON) for scripting—**ARIN** and the other RIRs support it and encourage it.[6](https://www.arin.net/resources/registry/whois/)[4](https://www.arin.net/resources/registry/whois/rws/faq/) - If you integrate **RIPE** data, ensure your usage aligns with RIPE Database terms/access controls.[14](https://www.ripe.net/manage-ips-and-asns/db/) --- ## 8) Parsing strategy (WHOIS is messy; RDAP is not) - **WHOIS**: server‑by‑server templates. Build a small rules engine: - detect registry (from referral chain) → apply a template (map of canonical keys to regex/aliases). - normalize fields (e.g., `Creation Date`, `Created`, `created:` → `created_at`). - keep raw text for audit. - Expect redaction (GDPR) and privacy proxies. - **RDAP**: strongly typed. Deserialize into structs; preserve **status**, **entities** (roles & contacts), **events** (e.g., `registration`, `expiration`), and **links**. 
---

## 9) CLI ergonomics

Examples:

```bash
# WHOIS then follow referrals (max 2 hops), print raw
whoirust example.com --whois-only --max-referrals 2

# RDAP JSON
whoirust example.com --rdap --json

# IP / ASN
whoirust 203.0.113.5 --rdap --json
whoirust AS13335 --rdap --json

# power user flags passed to RIPE whois
whoirust --whois-host whois.ripe.net -B -i admin-c 203.0.113.0/24
```

---

## 10) Putting it together: suggested folder layout

```
whoirust/
├─ Cargo.toml
├─ src/
│  ├─ lib.rs
│  ├─ core/{resolver.rs, psl.rs}
│  ├─ whois/{client.rs, parse.rs}
│  ├─ rdap/{bootstrap.rs, client.rs, types.rs}
│  ├─ limits/{rate.rs}
│  └─ cli/main.rs
└─ assets/public_suffix_list.dat
```

---

## 11) Beyond basics (roadmap)

- **RDAP search** endpoints (where enabled) for **reverse lookups** (e.g., `/entities?fn=...`) per RFC 9082.[2](https://datatracker.ietf.org/doc/html/rfc9082)
- **Abuse contact extraction**: RDAP `entities` with role `abuse` (often present for RIR IP data).[3](https://www.rfc-editor.org/rfc/rfc9083.html)
- **IRR** data (RIPE IRR/RADB) for route/maintainer pivots; helpful, but separate from WHOIS/RDAP.[14](https://www.ripe.net/manage-ips-and-asns/db/)
- Optional: integrate **OpenAPI** for RDAP (there’s a public spec) to generate types.[16](https://openapi.rdap.org/)

---

## 12) Quick test plan

- **Golden fixtures**: domain/IP/ASN queries for major registries (Verisign .com, PIR .org, a ccTLD, ARIN IP, RIPE IP, APNIC IP, a few registrars). Store **raw WHOIS**, **RDAP JSON**, and your **normalized JSON**; unit‑test parsing invariants.
- **Network tests**: mark as `ignored` in CI unless `LIVE_TESTS=1`.

---

## 13) Security & compliance notes

- Set a clear **User‑Agent** for RDAP; add a contact URL/email if possible (some servers rate‑limit unknown agents). RDAP is built for access control: if you need gated data later, the protocol supports auth (RFC 7481).[16](https://openapi.rdap.org/)
- Respect registry **ToU**; throttle and cache. (Verisign’s ToU is explicit about automated high‑volume use.)[7](https://www.verisign.com/en_US/domain-names/whois-lookup/index.xhtml)

---

### Want me to…

- bootstrap a **ready‑to‑build Cargo project** with the modules above,
- add **serde models** for the common RDAP responses,
- and wire a **rate‑limited async runtime** with `tokio`?

I can generate the full project skeleton and a couple of integration tests to get you shipping today.
```

### `scanners/mrpf_whois/src/main.rs`

**Absolute path:** `/root/mrpf/scanners/mrpf_whois/src/main.rs`

```rust
fn main() {
    println!("Hello, world!");
}
```

### `scanners/mrpf_whois/src/record.rs`

**Absolute path:** `/root/mrpf/scanners/mrpf_whois/src/record.rs`

```rust
struct WhoisRecord {
    pub domain: String,
    pub registrar: Option<String>,
    pub registrant: Option<String>,
    pub admin: Option<String>,
    pub tech: Option<String>,
    pub name_servers: Vec<String>,
    pub creation_date: Option<String>,
    pub expiration_date: Option<String>,
    pub updated_date: Option<String>,
    pub status: Vec<String>,
    pub emails: Vec<String>,
    pub dnssec: Option<String>,
    pub raw: String,
}
```

### `service_roles_security.md`

**Absolute path:** `/root/mrpf/service_roles_security.md`

```markdown
Details are here: Is this managed policy perhaps problematic cross-account?
{ "Effect": "Allow", "Action": "sts:AssumeRole", "Resource": "arn:aws:iam::*:role/EC2ImageBuilderDistributionCrossAccountRole" }, { "Effect": "Allow", "Action": [ "logs:CreateLogStream", "logs:CreateLogGroup", "logs:PutLogEvents" ], "Resource": "arn:aws:logs:*:*:log-group:/aws/imagebuilder/*" }, Ik vermoed dat in dit soort service roles VEEL meer problemen zitten dan we denken. Wellicht kunnen we wel heel erg veel cross account dingen doen doordat de service roles worden vertrouwd. Confused deputy achtige probleem wat alessandro laat zien maar is er niet veel meer dan dat? Het vreemde bij dit verhaal is dat ik de ssm:PutParameter permissie niet had op m'n image builder role, maar de error ging over de assumed-role awsserviceroleforimagebuilder. Is het mogelijk om de awsserviceroleforimagebuilder te assumen vanaf een EC2 instance role? Als dat zo is, kan ik dus eigenlijk alles voor elkaar krijgen dat die rol mag doen in de hele AWS wereld. Dit lijkt me heel onwaarschijnlijk. Als het WEL t geval is hebben we de jackpot. Dit was de error: Integration failed with JobId '34bf7015-9b05-436e-9e2d-b01c27416418', status = 'Failed' for ARN 'arn:aws:imagebuilder:eu-west-1:014534492047:image/mathijs-worker-ami/1.0.4/1'. Error:'In region 'eu-west-1' - 'SSM Client Error: 'User: arn:aws:sts::014534492047:assumed-role/AWSServiceRoleForImageBuilder/Ec2ImageBuilderIntegrationService is not authorized to perform: ssm:PutParameter on resource: arn:aws:ssm:eu-west-1:014534492047:parameter/mathijs-worker-ami/stripped-al2023-arm64 because no identity-based policy allows the ssm:PutParameter action''' Wellicht kan ik een custom workflow aanmaken die de service role mag assumen? Ik zie zoiezo deze statements: Lijkt er dus op dat ze ssm:PutParameter hebben toegevoegd aan de managed policy sinds kort. dat zal zijn waarom het niet werkte. Ze zeggen dit: Added ssm:PutParameter to allow Image Builder to update private SSM Parameters prefixed with /imagebuilder/ with the output AMI IDs that Image Builder creates from a new build. BETEKENT DIT DAT WE ssm:PutParameter kunnen aanpassen voor andere accounts?! - trust boundary is de hele service zonder conditie - De ssm:PutParameter actie is toegestaan op resources met prefix /imagebuilder/ - ALS ER VERDER GEEN CONDITIES ZIJN, KAN IK DAN EEN PARAMETER MAKEN IN MIJN ACCOUNT MET PREFIX /IMAGEBUILDER/ EN DIE LATEN SCHRIJVEN DOOR DE SERVICE ROLE? Het enige is dat ik moet weten wat voor parameters er zijn. Het probleem is hoe kunnen we zelf deze imagebuilder service assumen. Dat lijkt dus niet te kunnen, het word gedaan in de distribution stap. Wellicht kan ik iets proberen met 'Target accounts for distribution' in de distribution configuration ## I think this documentation is just not correct It states you need to add ssm:PutParameter to the Iamge Builder role, but its actually the servie role that does this? Prerequisites for SSM output parameters Before you create a new AMI distribution configuration that sets an AWS Systems Manager Parameter Store parameter (SSM parameter), ensure that you've met the following prerequisites. Execution role When you create a pipeline or use the create-image command in the AWS CLI, you can only specify one Image Builder execution role. If you have defined an Image Builder workflow execution role, you would add any additional feature permissions to that role. Otherwise, you would create a new custom role that includes the required permissions. 
To store the output AMI ID in an SSM parameter during distribution, you must specify the ssm:PutParameter action in your Image Builder execution role, with the parameter listed as a resource. When you set the parameter data type to AWS EC2 Image to signal Systems Manager to validate the parameter value as an AMI ID, you must also add the ec2:DescribeImages action.

Hmm, but they want you to set up a whole bunch of cross-account roles in the target account for cross-account distribution. It seems it's not using the service to assume cross-account roles here.

# Cross-account distribution role suggests a different SSM parameter name than the service role gets

"Resource": "arn:aws:ssm:*:111122223333:parameter/ImageBuilder-*"

vs. /imagebuilder/ in the service role managed policy:

{
    "Effect": "Allow",
    "Action": [
        "ssm:GetParameter",
        "ssm:PutParameter"
    ],
    "Resource": "arn:aws:ssm:*:*:parameter/imagebuilder/*"
},

## How can I find out what service roles do?

With this information I can perhaps search a bit more purposefully for things that service roles do.

OK, how can I find things that services do in my account, for bug bounty? CloudTrail event logs should be able to show this! For SSM you see, for example, this one: Ec2ImageBuilderIntegrationService

When I look at the event I see this principal:

"principalId": "AROAQGYSSX6HZIFUQRIUI:Ec2ImageBuilderIntegrationService",

```json
{
    "eventVersion": "1.11",
    "userIdentity": {
        "type": "AssumedRole",
        "principalId": "AROAQGYSSX6HZIFUQRIUI:Ec2ImageBuilderIntegrationService",
        "arn": "arn:aws:sts::014534492047:assumed-role/AWSServiceRoleForImageBuilder/Ec2ImageBuilderIntegrationService",
        "accountId": "014534492047",
        "accessKeyId": "ASIAQGYSSX6H2ZTZNORG",
        "sessionContext": {
            "sessionIssuer": {
                "type": "Role",
                "principalId": "AROAQGYSSX6HZIFUQRIUI",
                "arn": "arn:aws:iam::014534492047:role/aws-service-role/imagebuilder.amazonaws.com/AWSServiceRoleForImageBuilder",
                "accountId": "014534492047",
                "userName": "AWSServiceRoleForImageBuilder"
            },
            "attributes": {
                "creationDate": "2025-10-23T12:01:10Z",
                "mfaAuthenticated": "false"
            }
        },
        "invokedBy": "imagebuilder.amazonaws.com"
    },
    "eventTime": "2025-10-23T12:01:10Z",
    "eventSource": "ssm.amazonaws.com",
    "eventName": "PutParameter",
    "awsRegion": "eu-west-1",
    "sourceIPAddress": "imagebuilder.amazonaws.com",
    "userAgent": "imagebuilder.amazonaws.com",
    "errorCode": "AccessDenied",
    "errorMessage": "User: arn:aws:sts::014534492047:assumed-role/AWSServiceRoleForImageBuilder/Ec2ImageBuilderIntegrationService is not authorized to perform: ssm:PutParameter on resource: arn:aws:ssm:eu-west-1:014534492047:parameter/mathijs-worker-ami/stripped-al2023-arm64 because no identity-based policy allows the ssm:PutParameter action",
    "requestParameters": null,
    "responseElements": null,
    "requestID": "2346a6ff-1082-4389-920f-bedcaf451114",
    "eventID": "1eeb468f-1bfc-43c1-83ac-ef59d8d60235",
    "readOnly": false,
    "eventType": "AwsApiCall",
    "managementEvent": true,
    "recipientAccountId": "014534492047",
    "eventCategory": "Management"
}
```

## Image Builder workflows assume the service role

I see, for example, that it does this: ListInventoryEntries

It assumed the role "arn:aws:iam::014534492047:role/aws-service-role/imagebuilder.amazonaws.com/AWSServiceRoleForImageBuilder", via "imagebuilder.amazonaws.com".

This role has a trust boundary of the whole service, without conditions. Is that a problem? Maybe something somewhere in the custom workflows?
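Related to the "how can I find out what service roles do" question above: a rough sketch of pulling this out of CloudTrail programmatically. This is not code from the repo; it assumes the `aws_sdk_cloudtrail` crate, and the builder/accessor names are from memory, so verify them against the SDK docs before relying on it.

```rust
// Sketch: list recent CloudTrail events and keep the ones performed on our
// behalf by a service principal (e.g. imagebuilder.amazonaws.com acting
// through AWSServiceRoleForImageBuilder). Assumes aws_sdk_cloudtrail;
// signatures should be double-checked against the current SDK.
use aws_sdk_cloudtrail::types::{LookupAttribute, LookupAttributeKey};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = aws_config::load_from_env().await;
    let client = aws_sdk_cloudtrail::Client::new(&config);

    // Narrow the search to SSM calls; other event sources work the same way.
    let attr = LookupAttribute::builder()
        .attribute_key(LookupAttributeKey::EventSource)
        .attribute_value("ssm.amazonaws.com")
        .build()?;

    let resp = client
        .lookup_events()
        .lookup_attributes(attr)
        .max_results(50)
        .send()
        .await?;

    for event in resp.events() {
        // The raw event JSON contains "invokedBy" when a service made the call.
        if let Some(raw) = event.cloud_trail_event() {
            if raw.contains("\"invokedBy\"") {
                println!("{} {}", event.event_name().unwrap_or("?"), raw);
            }
        }
    }
    Ok(())
}
```

Filtering on `invokedBy` in the raw event JSON surfaces exactly the kind of service-on-your-behalf calls shown in the PutParameter event above.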
Dit zijn de acties die je kan doen in een custom workflow: ``` ### `services/mrpf_bare_metal_worker/src/cap_validation.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/cap_validation.rs` ```rust use mrpf_core::Result; #[cfg(target_os = "linux")] use mrpf_core::Error; #[cfg(target_os = "linux")] use std::{fs, io}; /// Taken from https://docs.rs/caps/latest/caps/enum.Capability.html #[cfg(target_os = "linux")] #[expect(non_camel_case_types)] pub(crate) enum Capability { CAP_CHOWN = 0, CAP_DAC_OVERRIDE = 1, CAP_DAC_READ_SEARCH = 2, CAP_FOWNER = 3, CAP_FSETID = 4, CAP_KILL = 5, CAP_SETGID = 6, CAP_SETUID = 7, CAP_SETPCAP = 8, CAP_LINUX_IMMUTABLE = 9, CAP_NET_BIND_SERVICE = 10, CAP_NET_BROADCAST = 11, CAP_NET_ADMIN = 12, CAP_NET_RAW = 13, CAP_IPC_LOCK = 14, CAP_IPC_OWNER = 15, CAP_SYS_MODULE = 16, CAP_SYS_RAWIO = 17, CAP_SYS_CHROOT = 18, CAP_SYS_PTRACE = 19, CAP_SYS_PACCT = 20, CAP_SYS_ADMIN = 21, CAP_SYS_BOOT = 22, CAP_SYS_NICE = 23, CAP_SYS_RESOURCE = 24, CAP_SYS_TIME = 25, CAP_SYS_TTY_CONFIG = 26, CAP_MKNOD = 27, CAP_LEASE = 28, CAP_AUDIT_WRITE = 29, CAP_AUDIT_CONTROL = 30, CAP_SETFCAP = 31, CAP_MAC_OVERRIDE = 32, CAP_MAC_ADMIN = 33, CAP_SYSLOG = 34, CAP_WAKE_ALARM = 35, CAP_BLOCK_SUSPEND = 36, CAP_AUDIT_READ = 37, CAP_PERFMON = 38, CAP_BPF = 39, CAP_CHECKPOINT_RESTORE = 40, // some variants omitted } /// Read a capability bitmask (hex) for a specific field from /proc/self/status. /// Example fields: "CapEff:", "CapPrm:", "CapBnd:", "CapAmb:", "NoNewPrivs:" #[cfg(target_os = "linux")] fn read_status_hex(field: &str) -> io::Result { let s = fs::read_to_string("/proc/self/status")?; for line in s.lines() { if line.starts_with(field) { // Format is "CapEff:\t0000000000000000" if let Some(hex) = line.split_whitespace().nth(1) { let val = u128::from_str_radix(hex, 16) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; return Ok(val); } } } Err(io::Error::new( io::ErrorKind::NotFound, format!("{} missing in /proc/self/status", field), )) } /// Returns true if the given bit is set in CapEff (Effective) for the current process. #[cfg(target_os = "linux")] fn has_effective_cap_bit(bit: u32) -> Result { let eff = read_status_hex("CapEff:").map_err(|e| Error::Other(format!("Read CapEff failed: {e}")))?; Ok((eff & (1u128 << bit)) != 0) } /// Optional: quick diagnostic helper to print relevant capability lines. /// Safe to call; ignores errors. #[cfg(target_os = "linux")] pub(crate) fn dump_cap_diagnostics() { if let Ok(s) = fs::read_to_string("/proc/self/status") { for key in ["NoNewPrivs:", "CapBnd:", "CapPrm:", "CapEff:", "CapAmb:"].iter() { if let Some(line) = s.lines().find(|l| l.starts_with(key)) { eprintln!("{line}"); } } } } /// Check if the current process has CAP_NET_RAW in its *Effective* set. pub(crate) fn check_raw_socket_capability() -> Result<()> { #[cfg(target_os = "linux")] { let bit = Capability::CAP_NET_RAW as u32; // 13 in your enum if has_effective_cap_bit(bit)? { Ok(()) } else { Err(Error::Other( "Insufficient capabilities: need CAP_NET_RAW".into(), )) } } #[cfg(not(target_os = "linux"))] { // On non-Linux systems, this check is not applicable. Ok(()) } } /// Check if the current process has CAP_SYS_ADMIN in its *Effective* set. pub(crate) fn check_shutdown_capability() -> Result<()> { #[cfg(target_os = "linux")] { let bit = Capability::CAP_SYS_ADMIN as u32; // 21 in your enum if has_effective_cap_bit(bit)? 
{ Ok(()) } else { Err(Error::Other( "Insufficient capabilities: need CAP_SYS_ADMIN".into(), )) } } #[cfg(not(target_os = "linux"))] { Ok(()) } } ``` ### `services/mrpf_bare_metal_worker/src/main.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/main.rs` ```rust mod cap_validation; mod shutdown; mod tracing_setup; mod worker; use cap_validation::{check_raw_socket_capability, check_shutdown_capability}; use std::env; use std::process::ExitCode; use tracing::{error, info}; use worker::start_worker; const DEFAULT_QUEUE_POLL_INTERVAL_SECS: u64 = 10; use std::time::Duration; /// Main entrypoint for the bare metal worker service. fn main() -> ExitCode { let args: Vec = env::args().collect(); if args.len() == 1 { handler() } else if args.len() == 2 && args[1] == "--validate-capabilities" { validate_binary() } else { eprintln!("Usage: {} [--validate-capabilities]", args[0]); ExitCode::FAILURE } } /// Validate if our Rust binary has the required capabilities to run as a bare metal worker. /// /// Useful for doing basic check whilst building a custom AWS AMI. fn validate_binary() -> ExitCode { info!("Validating bare metal worker binary setup..."); // // Check if the binary has the required capabilities match check_shutdown_capability() { Ok(_) => info!("Binary has the required capabilities to perform a shutdown."), Err(e) => { error!("Failed to check shutdown capability: {}", e); return ExitCode::FAILURE; } } match check_raw_socket_capability() { Ok(_) => println!("Binary has the required capabilities to use raw sockets."), Err(e) => { error!("Failed to check raw socket capability: {}", e); return ExitCode::FAILURE; } } info!("Bare metal worker binary validation successful."); ExitCode::SUCCESS } fn handler() -> ExitCode { // Parse shutdown interval from environment variable let shutdown_interval = match env::var("SHUTDOWN_INTERVAL") .ok() .and_then(|s| s.parse::().ok()) .map(Duration::from_secs) { Some(dur) => { info!( "Using configured SHUTDOWN_INTERVAL of {} seconds", dur.as_secs() ); Some(dur) } None => { info!("SHUTDOWN_INTERVAL not configured, will wait indefinitely for tasks"); None } }; // Parse queue poll interval from environment variable let optional_queue_poll_interval = env::var("QUEUE_POLL_INTERVAL") .ok() .and_then(|s| s.parse::().ok()) .map(Duration::from_secs); let queue_poll_interval: Duration = match optional_queue_poll_interval { Some(dur) => { info!( "Using configured QUEUE_POLL_INTERVAL of {} seconds", dur.as_secs() ); dur } None => { info!( "QUEUE_POLL_INTERVAL not configured, defaulting to {DEFAULT_QUEUE_POLL_INTERVAL_SECS} seconds" ); Duration::from_secs(DEFAULT_QUEUE_POLL_INTERVAL_SECS) } }; // Run the queue polling worker match tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Failed to create Tokio runtime") .block_on(start_worker(shutdown_interval, queue_poll_interval)) { Ok(_) => { info!("Worker service shutting down gracefully."); ExitCode::SUCCESS } Err(e) => { error!("Worker service encountered an error: {}", e); ExitCode::FAILURE } } } ``` ### `services/mrpf_bare_metal_worker/src/shutdown/linux.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/shutdown/linux.rs` ```rust use mrpf_core::{Error, Result}; use std::fs::{OpenOptions, read_to_string}; use std::io::Write; use std::path::Path; use std::time::Duration; /// Immediately powers off the machine by triggering the Linux Magic SysRq 'o'. /// /// - Bypasses normal shutdown (equivalent to pulling the power). /// - Requires root. 
/// - Works only on Linux with CONFIG_MAGIC_SYSRQ enabled and /proc mounted. pub(crate) fn force_shutdown() -> Result { // Paths for SysRq control and trigger let sysrq_enable = "/proc/sys/kernel/sysrq"; let sysrq_trigger = "/proc/sysrq-trigger"; // Sanity checks if !Path::new("/proc").exists() { return Err(Error::Other("/proc is not mounted".into())); } if !Path::new(sysrq_trigger).exists() { return Err(Error::Other( "Magic SysRq not available (missing /proc/sysrq-trigger). Kernel may be built without CONFIG_MAGIC_SYSRQ, or permissions are insufficient.".into(), )); } // Ensure SysRq is enabled. Writing "1" enables all functions. if Path::new(sysrq_enable).exists() { // If we can read current value and it's not "1", enable it. let needs_enable = match read_to_string(sysrq_enable) { Ok(s) => s.trim() != "1", Err(_) => true, // If unreadable, try to enable anyway. }; if needs_enable { let mut f = OpenOptions::new().write(true).open(sysrq_enable)?; // "1" == enable all SysRq functions (bitmask shortcut) f.write_all(b"1")?; f.flush()?; // Give the kernel a moment to apply the setting. std::thread::sleep(Duration::from_millis(10)); } } // Trigger immediate power-off: write 'o' to /proc/sysrq-trigger. // After this write succeeds, the machine powers off immediately // and no further Rust code will run. let mut t = OpenOptions::new().write(true).open(sysrq_trigger)?; t.write_all(b"o")?; t.flush()?; Ok(()) } ``` ### `services/mrpf_bare_metal_worker/src/shutdown/mod.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/shutdown/mod.rs` ```rust #[cfg(target_os = "linux")] mod linux; #[cfg(target_os = "linux")] pub(crate) use linux::force_shutdown; #[cfg(not(target_os = "linux"))] mod other; #[cfg(not(target_os = "linux"))] pub(crate) use other::force_shutdown; ``` ### `services/mrpf_bare_metal_worker/src/shutdown/other.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/shutdown/other.rs` ```rust use mrpf_core::Result; /// Fallback stub for non-Linux targets. pub fn force_shutdown() -> Result { unimplemented!("force_shutdown() is only supported on Linux targets"); } ``` ### `services/mrpf_bare_metal_worker/src/tracing_setup.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/tracing_setup.rs` ```rust use aws_sdk_cloudwatchlogs::Client as CloudWatchLogsClient; use mrpf_core::{Error, Result}; use std::env; use std::io; use std::time::Duration; use tracing::error; use tracing_subscriber::EnvFilter; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; /// Setup tracing to send logs to CloudWatch Logs. 
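///
/// Reads `CW_LOG_GROUP_NAME` and `CW_LOG_STREAM_NAME` from the environment,
/// creates the log stream, and installs a stdout fmt layer plus a batching
/// CloudWatch layer (batch size 5, 5 second flush interval), both gated by
/// the usual `RUST_LOG` env filter (defaulting to INFO).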
pub(crate) async fn setup_cloudwatch_tracing(cw_client: CloudWatchLogsClient) -> Result { let log_group_name = env::var("CW_LOG_GROUP_NAME")?; let log_stream_name = env::var("CW_LOG_STREAM_NAME")?; cw_client .create_log_stream() .log_group_name(&log_group_name) .log_stream_name(&log_stream_name) .send() .await .map_err(|e| { error!( error = %e, log_group = %log_group_name, log_stream = %log_stream_name, "Failed to create CloudWatch log stream" ); Error::Other(format!("Failed to create CloudWatch log stream: {e}")) })?; let env_filter = EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(); // stdout: info and below (trace/debug/info) let stdout_layer = tracing_subscriber::fmt::layer() .without_time() .with_target(false) .with_ansi(false) .with_writer(io::stdout); // cloudwatch layer, still using env_filter for RUST_LOG let cw_layer = tracing_cloudwatch::layer() .with_client( cw_client, tracing_cloudwatch::ExportConfig::default() .with_batch_size(5) .with_interval(Duration::from_secs(5)) .with_log_group_name(log_group_name) .with_log_stream_name(log_stream_name), ) .with_code_location(true) .with_target(false); tracing_subscriber::registry::Registry::default() .with(stdout_layer) .with(cw_layer) .with(env_filter) .init(); Ok(()) } /// Setup tracing with separate stdout and stderr layers. /// /// Honor RUST_LOG if set, default to info. #[expect(dead_code)] fn setup_tracing() { tracing_subscriber::fmt() // disabling time is handy because CloudWatch will add the ingestion time. .without_time() // remove the name of the function from every log entry .with_target(false) // disable ANSI color codes in stdout .with_ansi(false) .init(); } ``` ### `services/mrpf_bare_metal_worker/src/worker.rs` **Absolute path:** `/root/mrpf/services/mrpf_bare_metal_worker/src/worker.rs` ```rust use crate::shutdown::force_shutdown; use crate::tracing_setup::setup_cloudwatch_tracing; use aws_sdk_cloudwatchlogs::Client as CloudWatchLogsClient; use lambda_helpers::get_db_connection_details; use mrpf_core::DbClient; use mrpf_core::tasks::models::WorkerQueueMessage; use mrpf_core::tasks::{ExecutionContext, QueueClient}; use mrpf_core::{Error, Result}; use mrpf_task_worker::handle_tasks; use std::time::{Duration, Instant}; use tracing::{debug, error, info}; /// Init context object for the worker, including DB and Queue clients. async fn init_context(config: &aws_config::SdkConfig) -> Result { let queue_client = QueueClient::new(config); // This is populated in UserData script of the bare metal worker instance let public_ip = match std::env::var("PUBLIC_IPV4") { Ok(ip) => ip, Err(_) => { error!("PUBLIC_IPV4 environment variable not set"); "".into() } }; info!("Worker public IP address: {}", public_ip); match get_db_connection_details(config).await { Ok((host, port, database_name, username, password)) => { // We're not properly leveraging the WorkerRequirements here as we have the same lambda code for handling // internal and external workers. Perhaps we can do something more with ExecutionContextBuilder eventually // but maybe thats not needed in Lambda. It could be more useful perhaps in more permanent compute environments? debug!( "Successfully retrieved database connection details, assuming we're dealing with an Internal Worker" ); let db = DbClient::new( &host, port, &database_name, &username, &password, "worker_lambda", "task_manager,util,public", // TODO: perhaps we need to move this to the client itself? 
) .await?; Ok(ExecutionContext::builder() .db(db) .queue(queue_client) .build()) } Err(_) => Err(Error::Other( "Bare metal worker must not have a task DB client".into(), )), } } /// Starts the main polling loop for the worker service. /// /// # Arguments /// * `shutdown_interval` - Option - The duration after which the service should shut down /// * `queue_poll_interval` - Duration - The interval between polling the SQS queue pub(crate) async fn start_worker( shutdown_interval: Option, queue_poll_interval: Duration, ) -> Result { let config = aws_config::load_from_env().await; let cw_client = CloudWatchLogsClient::new(&config); setup_cloudwatch_tracing(cw_client).await?; let context = init_context(&config).await?; let mut last_message_at = Instant::now(); // Main polling loop loop { match context .queue_client()? .receive_bare_metal_worker_message() .await { Err(e) => { // Continue to next iteration on error error!("Error with receiving or parsing message from queue: {}", e); tokio::time::sleep(queue_poll_interval).await; continue; } Ok(worker_queue_message) => match worker_queue_message { Some(worker_queue_message) => { info!("Received a message from the queue."); match &worker_queue_message { WorkerQueueMessage::StartTask { root_container_id, task, } => { info!( "Pulled StartTask message: container_id={:?}, task_id={:?}", root_container_id, task.id() ); let messages = handle_tasks(vec![(*root_container_id, task.clone())], &context) .await?; if !messages.is_empty() { info!( message_count = %messages.len(), "Sending TaskManagerQueueMessage(s) to Task Manager", ); context.queue_client()?.send_mgr_messages(&messages).await?; } else { info!("No messages to send to Task Manager"); } } } // Reset the interval counter when a message is found last_message_at = Instant::now(); } None => { debug!("No message found on SQS queue."); if let Some(interval) = shutdown_interval { let elapsed = last_message_at.elapsed(); debug!( "Time since last message: {} seconds, configured shutdown interval: {} seconds", elapsed.as_secs(), interval.as_secs() ); if elapsed >= interval { info!( "No message received within {} seconds, proceeding to shutdown.", interval.as_secs() ); break; } else { debug!("Shutdown interval not reached yet, continuing to poll."); } } else { debug!( "No shutdown interval configured; waiting indefinitely for a message." ); } } }, } tokio::time::sleep(queue_poll_interval).await; } info!("Loop finished, initiating shutdown sequence..."); // Allow time for logs to flush tokio::time::sleep(Duration::from_secs(6)).await; if let Err(e) = force_shutdown() { error!("Failed to execute shutdown command: {}", e); // TODO Terminate instance using AWS API? 
return Err(Error::Other(e.to_string())); } else { info!("Shutdown command executed successfully."); } Ok(()) } ``` ### `services/mrpf_cert_transparency/src/controllers/certclient.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/controllers/certclient.rs` ```rust use crate::models::retryconfig::RetryError; use crate::models::retryconfig::RetryError::{ConnectionError, RateLimited}; use crate::{ LogServer, LogServerEntry, LogServerRange, LogServerRangeState, deserialize_leaf_input, extract_cn_and_san, fqdn_cleanup, store_fqdns_until_succeeded, update_ranges_until_succeeded, }; use base64::Engine as _; use rand::Rng; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashSet; use tokio::time::{Duration, Instant, sleep}; use tracing::{debug, error, info}; const NAME_CACHE_SIZE: usize = 512; // Log servers not present on the google json file // crt.sh has a list of servers it monitors: https://crt.sh/monitored-logs const STATIC_LOG_SERVERS: [&str; 1] = [ "https://gorgon.ct.digicert.com/log/", // Found this on crt.sh when looking at asnbank.nl certificate. ]; const SERVERS_TO_SKIP: [&str; 47] = [ // Completed ranges!! // "https://sphinx.ct.digicert.com/2024h1/", // "https://ct.googleapis.com/logs/xenon2019/", // "https://ct.googleapis.com/logs/xenon2018/", // "https://wyvern.ct.digicert.com/2024h1/", // Temporary skipping as this is very often rate limiting us. I want to bring this back but first want to see if I can get some mileage on the other servers // "https://mammoth2025h1.ct.sectigo.com/", // "https://sabre2025h1.ct.sectigo.com/", // "https://ct2024.trustasia.com/log2024/", // Got permanently rate limited from my home ip it seems // "https://ct2025-b.trustasia.com/log2025b/", // Very heavily rate limited // "https://ct2025-a.trustasia.com/log2025a/", // Very heavily rate limited // "https://sphinx.ct.digicert.com/2025h1/", // Very heavily rate limited // "https://sphinx.ct.digicert.com/2024h2/", // heavily rate limit // "https://yeti2025.ct.digicert.com/log/", // heavily rate limit // "https://nessie2024.ct.digicert.com/log/", // heavily rate limit // "https://wyvern.ct.digicert.com/2024h2/", // heavily rate limit // "https://wyvern.ct.digicert.com/2025h1/", // Very heavily rate limited // "https://oak.ct.letsencrypt.org/2024h2/", // heavily rate limit // "https://nessie2025.ct.digicert.com/log/", // heavily rate limit // Following ranges are likely never reachable anymore "https://ct.trustasia.com/log2022/", // 404 error "https://sabre2024h2.ct.sectigo.com/", // DNS error "https://ct2.trustasia.com/log2024/", // DNS error "https://ct2021.trustasia.com/log2021/", // DNS Errors "https://sabre2024h1.ct.sectigo.com/", // DNS Errors "https://yeti2018.ct.digicert.com/log/", // DNS Errors "https://yeti2019.ct.digicert.com/log/", // DNS Errors "https://yeti2020.ct.digicert.com/log/", // DNS Errors "https://yeti2021.ct.digicert.com/log/", // DNS Errors "https://yeti2022.ct.digicert.com/log/", // DNS Errors "https://nessie2018.ct.digicert.com/log/", // DNS Errors "https://nessie2019.ct.digicert.com/log/", // DNS Errors "https://nessie2020.ct.digicert.com/log/", // DNS Errors "https://nessie2021.ct.digicert.com/log/", // DNS Errors "https://mammoth2024h1.ct.sectigo.com/", // DNS Errors "https://mammoth2024h1b.ct.sectigo.com/", // DNS Errors "https://mammoth2024h2.ct.sectigo.com/", // DNS Errors "https://log.gdca.com.cn/", // DNS Errors "https://log2.gdca.com.cn/", // DNS Errors "https://testflume.ct.letsencrypt.org/2019/", // 
DNS Errors "https://testflume.ct.letsencrypt.org/2020/", // DNS Errors "https://testflume.ct.letsencrypt.org/2021/", // DNS Errors "https://testflume.ct.letsencrypt.org/2022/", // DNS Errors "https://testflume.ct.letsencrypt.org/2023/", // DNS Errors "https://ct.browser.360.cn/2020/", // Timeout errors getting tree size "https://ct.browser.360.cn/2021/", // Timeout errors getting tree size "https://ct.browser.360.cn/2022/", // Timeout errors getting tree size "https://ct.browser.360.cn/2023/", // Timeout errors getting tree size "https://ct.browser.360.cn/v1/2020/", // Timeout errors getting tree size "https://ct.browser.360.cn/v1/2021/", // Timeout errors getting tree size "https://ct.browser.360.cn/v1/2022/", // Timeout errors getting tree size "https://ct.browser.360.cn/v1/2023/", // Timeout errors getting tree size "https://sapling.ct.letsencrypt.org/2023h1/", // 404 Errors "https://oak.ct.letsencrypt.org/2021/", // 404 Errors "https://ct.cloudflare.com/logs/nimbus2018/", // 404 Errors "https://ct.trustasia.com/log1/", // 404 Errors "https://oak.ct.letsencrypt.org/2019/", // 404 Errors "https://ct.cloudflare.com/logs/nimbus2020/", // 404 Errors "https://ct.trustasia.com/log2020/", // 404 Errors "https://ct.cloudflare.com/logs/nimbus2017/", // 404 Errors "https://oak.ct.letsencrypt.org/2024h1/", // 404 Errors "https://ct.trustasia.com/log2023/", // 404 Errors "https://oak.ct.letsencrypt.org/2020/", // 404 Errors "https://ct.cloudflare.com/logs/nimbus2019/", // 404 Errors "https://ct.trustasia.com/log2021/", // 404 Errors "https://sapling.ct.letsencrypt.org/2022h2/", // 404 Errors "https://ct.cloudflare.com/logs/nimbus2021/", // 404 Errors ]; /// A Certificate Transparency log entry. /// /// Represents an entry in the Merkle Tree, containing either an X509 certificate or a precertificate. /// The `leaf_input` is the primary certificate of interest (end-entity or precertificate), while /// `extra_data` contains the certificate chain. /// /// See [RFC 6962](https://datatracker.ietf.org/doc/html/rfc6962) for details on the structure. #[derive(Serialize, Deserialize, Debug)] pub struct CertTransparencyEntry { pub leaf_input: String, pub extra_data: String, } /// Client for interacting with Certificate Transparency logs. /// /// Supports version 1 of the CT API (most widely supported). Retrieves and processes log entries /// from a specified log server, handling rate limits and connection errors via `RetryConfig`. /// /// # Notes /// - Uses `reqwest` for HTTP requests with a custom client configuration. /// - Filters domains based on a configurable list (currently hardcoded). /// - Stores results in DynamoDB via provided helper functions. /// /// The RFC describing the API is [here](https://datatracker.ietf.org/doc/html/rfc6962) pub struct CertTransparencyClient { base_url: String, session: Client, log_server: LogServer, filter: Vec, connection_errors: u64, rate_limited_errors: u64, scan_range_errors: u64, } impl CertTransparencyClient { /// Creates a new `CertTransparencyClient` for a given log server. /// /// # Arguments /// * `log_server` - The `LogServer` instance to query. Uses the mirror URL if available. /// /// # Returns /// A configured client instance. 
pub fn new(log_server: LogServer) -> Self { let url = if let Some(mirror) = log_server.mirror() { mirror } else { log_server.url() } .trim_end_matches('/') .to_string(); let base_url = if url.starts_with("http") { format!("{:}/ct/v1", url) } else { format!("https://{}/ct/v1", url) }; let session = Client::builder() .danger_accept_invalid_certs(true) .timeout(Duration::from_secs(15)) .build() .expect("Failed to build reqwest client"); CertTransparencyClient { base_url, session, log_server, filter: vec![ "flowers-to-the-world.com".to_string(), "ct-test.ssl.pub".to_string(), "testing.letsencrypt.org".to_string(), ], connection_errors: 0, rate_limited_errors: 0, scan_range_errors: 0, } } /// Retrieve the URL for the log server. pub fn url(&self) -> String { self.log_server.url().to_string() } /// Retrieves the Signed Tree Head (STH) from the log server. /// /// Fetches the current Merkle Tree size and last update timestamp. /// /// # Returns /// * `Ok((timestamp, tree_size))` - The timestamp (ms since epoch) and tree size. /// * `Err(RetryError)` - On connection failure, rate limit, or max retries exceeded. /// /// # Errors /// - Returns `RetryError::ConnectionError` if the server is unreachable. /// - Retries on rate limits (HTTP 429) using `RetryConfig`. pub async fn get_sth(&mut self) -> Result<(i64, i64), RetryError> { loop { let response = self .session .get(format!("{}/get-sth", self.base_url)) .send() .await; let response = match response { Ok(response) => response, Err(e) if e.is_connect() => { // Quite a bunch of old servers don't exist anymore so we will get DNS connect errors self.connection_errors += 1; error!("Connection error: {}", e); return Err(ConnectionError); } Err(e) => match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds: {}", self.base_url, backoff.as_secs(), e ); sleep(backoff).await; continue; } error => return Err(error), }, }; if response.status().is_success() { let response_json: Value = match response.json().await { Ok(response_json) => response_json, Err(e) => match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds: {}", self.base_url, backoff.as_secs(), e ); sleep(backoff).await; continue; } error => return Err(error), }, }; let tree_size = response_json["tree_size"].as_i64().unwrap(); let timestamp = response_json["timestamp"].as_i64().unwrap(); // let sha256_root_hash = response_json["sha256_root_hash"].as_str().unwrap(); // let tree_head_signature = response_json["tree_head_signature"].as_str().unwrap(); return Ok((timestamp, tree_size)); // } else if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS { } else { match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds: {}", self.base_url, backoff.as_secs(), response.status() ); sleep(backoff).await; continue; } error => return Err(error), } } } } /// Retrieve the entries in the Merkle Tree between the given indices /// /// The maximum number of entries that a server returns varies so the /// calling function of get_entries will have to ensure that all entries /// are retrieved. 
/// /// # Arguments /// /// * `start` - The start index /// * `end` - The end index /// /// # Returns /// /// A vector with the entries pub async fn get_entries( &mut self, start: i64, end: i64, ) -> Result, RetryError> { let url = format!("{}/get-entries?start={}&end={}", self.base_url, start, end); loop { let response = self.session.get(&url).send().await; let response = match response { Ok(response) => response, Err(_e) => match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds", self.base_url, backoff.as_secs() ); sleep(backoff).await; continue; } error => { self.scan_range_errors += 1; return Err(error); } }, }; if response.status().is_success() { let response_json: Value = match response.json().await { Ok(response_json) => response_json, Err(_e) => match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds", self.base_url, backoff.as_secs() ); sleep(backoff).await; continue; } error => { self.scan_range_errors += 1; return Err(error); } }, }; match serde_json::from_value::>( response_json["entries"].clone(), ) { Ok(entries) => { if entries.is_empty() { // If we don't get any entries back, we should rate limit match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds", self.base_url, backoff.as_secs() ); sleep(backoff).await; continue; } error => { self.scan_range_errors += 1; return Err(error); } } } else { // We got some entries back, so was a succesful request return Ok(entries); } } Err(e) => { panic!("{}: {} {}", e, url, response_json) } }; } else if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS { match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds", self.base_url, backoff.as_secs() ); sleep(backoff).await; continue; } error => { self.scan_range_errors += 1; return Err(error); } } } else { // Ok, we're really fucking up the server. 
Lets backoff a whole bunch // warn!("{} {}-{}: {}", self.base_url, start, end, response.status()); let backoff = Duration::from_secs(rand::rng().random_range(15..=30)); sleep(backoff).await; match self.log_server.retry() { RateLimited(backoff) => { self.rate_limited_errors += 1; debug!( "{}: Rate limited, retrying in {} seconds", self.base_url, backoff.as_secs() ); sleep(backoff).await; continue; } error => { self.scan_range_errors += 1; return Err(error); } } } } } /// Scan a single range async fn scan_range( &mut self, range: &mut LogServerRange, ) -> Result, (std::io::Error, HashSet)> { let mut local_names = HashSet::new(); let range_string = format!("{}-{}", range.start(), range.end()); loop { match self.get_entries(range.current(), range.end()).await { Ok(entries) => { let num_entries = entries.len() as i64; if num_entries == 0 { panic!("This should be caught by the get_entries function"); } // Parse names from entries for entry in entries { let leaf_input_bytes = base64::prelude::BASE64_STANDARD .decode(&entry.leaf_input) .unwrap(); let leaf = deserialize_leaf_input(leaf_input_bytes).unwrap(); match leaf.entry { LogServerEntry::X509Entry(cert) => { let cert = cert[3..].to_vec(); let (cn, san) = extract_cn_and_san(cert).unwrap(); if let Some(mut cn) = cn { fqdn_cleanup(&mut cn); if !self.filter.contains(&cn) { local_names.insert(cn); } } for mut name in san { fqdn_cleanup(&mut name); if !self.filter.contains(&name) { local_names.insert(name); } } } LogServerEntry::PrecertEntry(_cert) => {} } } // Update the range cursor range.set_current(num_entries); debug!( "{} - {} new entries, remaining {} (+/- {} iterations)", range.log_server(), num_entries, range.end() - range.current(), (range.end() - range.current()) / num_entries ); if range.is_done() { return Ok(local_names); } } Err(e) => { return Err(( std::io::Error::other(format!( "{} failed to retrieve range {} fully. current cursor is {}, {} entries left: {}", range.log_server(), range_string, range.current(), range.end() - range.current(), e )), local_names, )); } } } } pub async fn scan_ranges( &mut self, db_client: &aws_sdk_dynamodb::Client, ranges: Vec, ) { let range_count = ranges.len(); let mut ranges_remaining = range_count; let mut failed_range_count = 0; let mut succeeded_range_count = 0; let mut completed_ranges: Vec = Vec::new(); let mut all_names = HashSet::new(); let start_time = chrono::Utc::now(); // Flush names and ranges to database every 30 seconds let flush_time = Duration::from_secs(30); let mut last_flush_time = Instant::now(); for mut range in ranges { let range_start_time = chrono::Utc::now(); let range_string = format!("{}-{}", range.start(), range.end()); match self.scan_range(&mut range).await { Ok(mut names) => { completed_ranges.push(range); all_names.extend(names.drain()); succeeded_range_count += 1; ranges_remaining -= 1; info!( "{} {} is done in {} sec. {} remaining", self.log_server, range_string, (chrono::Utc::now() - range_start_time).num_seconds(), ranges_remaining ); } Err((e, mut names)) => { // Rate limited. failed_range_count += 1; ranges_remaining -= 1; info!("{}: {}", self.log_server, e); if range.current() > range.start() { // We have a partial range, so we can store the names we found // and mark the range as completed. Our scheduler will // see the range marked as done and should correctly // deduct the remaining range to scan. 
let partial_range = LogServerRange::new( range.log_server().to_string(), LogServerRangeState::Done, range.start(), range.current(), ); let range_string = format!("{}-{}", partial_range.start(), partial_range.end()); info!( "{} Adding partially completed range {} to queue, found {} names", self.log_server, range_string, names.len() ); completed_ranges.push(partial_range); all_names.extend(names.drain()); } } } // Handle accumulated names if all_names.len() > NAME_CACHE_SIZE || last_flush_time.elapsed() > flush_time { if !all_names.is_empty() { store_fqdns_until_succeeded(db_client, &all_names).await; info!( "{} Stored {} names in DB, {} ranges left (started at {})", self.log_server, all_names.len(), ranges_remaining, start_time.format("%H:%M:%S") ); all_names.clear(); } if !completed_ranges.is_empty() { update_ranges_until_succeeded(db_client, &completed_ranges).await; info!( "{} Updated {} completed ranges in DB", self.log_server.url(), completed_ranges.len() ); completed_ranges.clear(); } last_flush_time = Instant::now(); } } // Handle any potential lingering names let total_duration = chrono::Utc::now() - start_time; if !all_names.is_empty() { store_fqdns_until_succeeded(db_client, &all_names).await; info!( "{} stored {} remaining names in DB", self.log_server.url(), all_names.len() ); } if !completed_ranges.is_empty() { update_ranges_until_succeeded(db_client, &completed_ranges).await; info!( "{} updated {} completed ranges in DB", self.log_server.url(), completed_ranges.len() ); } info!( "{} completed. {} total ranges in {} minutes. {} succeeded, {} failed", self.log_server.url(), range_count, total_duration.num_minutes(), succeeded_range_count, failed_range_count, ); } } /// Retrieve all available Certificate Transparency log servers /// /// TODO: Get certificates from their special servers: /// https://github.com/google/certificate-transparency-community-site/blob/master/docs/google/known-logs.md#special-purpose-logs /// /// https://ct.googleapis.com/daedalus/ /// Daedalus logs certificates that have expired /// /// https://ct.googleapis.com/submariner/ /// Submariner only logs certificates that chain to roots that are on track for inclusion in browser roots or were trusted at some previous point announcement blog post. /// /// TODO: SSLMate CertSpotter also maintains a log list which might be easier than google's one: /// https://sslmate.com/resources/certspotter_stats /// https://feeds.sslmate.com/ct_logs.csv /// /// # Arguments /// * `fetch_only_current` - If true, only fetch the current log servers pub async fn get_log_servers( fetch_active: bool, fetch_inactive: bool, ) -> Result, reqwest::Error> { // let response: Value = reqwest::get("https://www.gstatic.com/ct/log_list/v3/log_list.json") // Apple's list is here: https://valid.apple.com/ct/log_list/current_log_list.json let response: Value = reqwest::get("https://www.gstatic.com/ct/log_list/v3/all_logs_list.json") .await? 
.json() .await?; let mut log_servers = HashSet::new(); let now = chrono::Utc::now(); for operator in response["operators"].as_array().unwrap() { for log in operator["logs"].as_array().unwrap() { let url = log["url"].as_str().unwrap(); if SERVERS_TO_SKIP.contains(&url) { continue; } if log.get("temporal_interval").is_none() { continue; } if let Some(_log_type) = log.get("log_type") { // info!("Skipping log type {} {}", log_type, log["url"]); continue; } let end_date = chrono::DateTime::parse_from_rfc3339( log["temporal_interval"]["end_exclusive"].as_str().unwrap(), ) .expect("Invalid end date"); if now < end_date { if fetch_active { log_servers.insert(LogServer::new(url.to_string())); } } else if fetch_inactive { log_servers.insert(LogServer::new(url.to_string())); } } } for log_server in STATIC_LOG_SERVERS { log_servers.insert(LogServer::new(log_server.to_string())); } Ok(log_servers) } ``` ### `services/mrpf_cert_transparency/src/controllers/database.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/controllers/database.rs` ```rust use crate::LogServerRange; use aws_sdk_dynamodb::Error; use aws_sdk_dynamodb::error::SdkError; use aws_sdk_dynamodb::types::AttributeValue; use aws_sdk_dynamodb::types::{DeleteRequest, PutRequest, WriteRequest}; use futures::future::join_all; use rand::Rng; use std::collections::HashSet; use tokio::time::{Duration, sleep}; use tracing::{debug, error}; /// Retrieve all the known ranges for a given log server from DynamoDB. pub async fn get_log_server_ranges( client: &aws_sdk_dynamodb::Client, log_server: String, ) -> Result, Error> { let request = client .query() .table_name("CertTransparencyTable") .key_condition_expression("#pk = :pk") .expression_attribute_names("#pk", "PK") .expression_attribute_values( ":pk", AttributeValue::S(format!("LOGSERVER#{}", log_server)), ) .send() .await?; if let Some(ranges) = request.items { let ranges = ranges.iter().map(|v| v.into()).collect(); Ok(ranges) } else { Ok(vec![]) } } /// Store a set of FQDNs in the CertTransparencyTable DynamoDB table using UpdateItem. /// /// Each FQDN is stored with a primary key (PK) of the form `D#.` /// and a sort key (SK) of the form `#`. /// /// # Arguments /// * `client` - AWS DynamoDB client. /// * `fqdns` - Set of fully qualified domain names to store. /// /// # Errors /// Returns an error if any DynamoDB operation fails or if the table is misconfigured. pub async fn store_fqdns_until_succeeded( client: &aws_sdk_dynamodb::Client, fqdns: &HashSet, ) { debug!("Inserting {} fqdns", fqdns.len()); let mut requests = Vec::with_capacity(fqdns.len()); for fqdn in fqdns { // Validate FQDN let split_fqdn = fqdn.split('.').collect::>(); if split_fqdn.len() < 2 || split_fqdn.iter().any(|&part| part.is_empty()) || fqdn.len() > 255 { debug!("Skipping invalid FQDN: {}", fqdn); continue; } // Construct primary key: D#. 
let primary_key = format!( "D#{}.{}", split_fqdn[split_fqdn.len() - 2], split_fqdn[split_fqdn.len() - 1] ); // Construct sort key: # let sort_key = if split_fqdn.len() > 2 { format!( "#{}", split_fqdn[..split_fqdn.len() - 2] .iter() .rev() .cloned() .collect::>() .join("#") ) } else { "#".to_string() }; // Build UpdateItem request let request = WriteRequest::builder() .put_request( PutRequest::builder() .item("PK", AttributeValue::S(primary_key)) .item("SK", AttributeValue::S(sort_key)) // .item("fqdn", AttributeValue::S(fqdn.to_string())) .build() .expect("Failed to build PutRequest"), ) .build(); requests.push(request); } batch_write_until_succeeded(client, "CertTransparencyTable", requests).await; } /// Upserts a set of FQDNs in the CertTransparencyTable DynamoDB table using UpdateItem. /// Each FQDN is stored with a primary key (PK) of the form `D#.` /// and a sort key (SK) of the form `#`. /// /// # Arguments /// * `client` - AWS DynamoDB client. /// * `fqdns` - Set of fully qualified domain names to store. /// /// # Errors /// Returns an error if any DynamoDB operation fails or if the table is misconfigured. pub async fn update_fqdns( client: &aws_sdk_dynamodb::Client, fqdns: &HashSet, ) -> Result<(), Error> { const BATCH_SIZE: usize = 25; // Reasonable batch size for UpdateItem concurrency let mut requests = Vec::with_capacity(fqdns.len()); for fqdn in fqdns { // Validate FQDN let split_fqdn = fqdn.split('.').collect::>(); if split_fqdn.len() < 2 || split_fqdn.iter().any(|&part| part.is_empty()) || fqdn.len() > 255 { debug!("Skipping invalid FQDN: {}", fqdn); continue; } // Construct primary key: D#. let primary_key = format!( "D#{}.{}", split_fqdn[split_fqdn.len() - 2], split_fqdn[split_fqdn.len() - 1] ); // Construct sort key: # let sort_key = if split_fqdn.len() > 2 { format!( "#{}", split_fqdn[..split_fqdn.len() - 2] .iter() .rev() .cloned() .collect::>() .join("#") ) } else { "#".to_string() }; // Build UpdateItem request let request = client .update_item() .table_name("CertTransparencyTable") .key("PK", AttributeValue::S(primary_key)) .key("SK", AttributeValue::S(sort_key)) .update_expression("SET #fqdn = :fqdn") .expression_attribute_names("#fqdn", "fqdn") .expression_attribute_values(":fqdn", AttributeValue::S(fqdn.to_string())) .send(); requests.push(request); } // Process requests in chunks using drain let mut remaining = requests; let mut batch_num = 0; while !remaining.is_empty() { batch_num += 1; let batch_size = remaining.len().min(BATCH_SIZE); let batch: Vec<_> = remaining.drain(..batch_size).collect(); // Take owned futures debug!("Processing batch {} of {} FQDNs", batch_num, batch.len()); let results = join_all(batch).await; for result in results.into_iter() { // Await each future match result { Ok(_) => {} Err(e) => { panic!("Error storing FQDN: {:?}", e); } } } } Ok(()) } /// Update the range entry in the database. /// /// # Arguments /// /// * `client` - The DynamoDB client to use for updating the range entry. /// * `log_server` - The log server to update the range for. /// * `range` - The message containing the log server and range to update. 
pub async fn update_range( client: &aws_sdk_dynamodb::Client, range: &LogServerRange, ) -> Result<(), Error> { client .put_item() .table_name("CertTransparencyTable") .item("PK", AttributeValue::S(range.primary_key())) .item("SK", AttributeValue::S(range.sort_key())) .item("state", AttributeValue::S(range.state().to_string())) .send() .await?; Ok(()) } /// Remove LogServerRanges from DynamoDB /// /// # Arguments /// /// * `client` - The DynamoDB client to use for updating the ranges /// * `log_server` - The log server to update the ranges for /// * `ranges` - The ranges to delete from the database pub async fn delete_ranges_until_succeeded( client: &aws_sdk_dynamodb::Client, ranges: &HashSet, ) { // Remove the old ranges from the database debug!("Deleting {} old ranges", ranges.len()); let mut requests = Vec::with_capacity(ranges.len()); for range in ranges { let delete_request = WriteRequest::builder() .delete_request( DeleteRequest::builder() .key("PK", AttributeValue::S(range.primary_key())) .key("SK", AttributeValue::S(range.sort_key())) .build() .expect("Failed to build DeleteRequest"), ) .build(); requests.push(delete_request); } batch_write_until_succeeded(client, "CertTransparencyTable", requests).await; } /// Update LogServerRanges in DynamoDB /// /// # Arguments /// /// * `client` - The DynamoDB client to use for updating the ranges /// * `log_server` - The log server to update the ranges for /// * `ranges` - The ranges to add to the database pub async fn update_ranges_until_succeeded( client: &aws_sdk_dynamodb::Client, ranges: &[LogServerRange], ) { debug!("Updating {} ranges", ranges.len()); let mut requests = Vec::with_capacity(ranges.len()); for range in ranges { let put_request = WriteRequest::builder() .put_request( PutRequest::builder() .item("PK", AttributeValue::S(range.primary_key())) .item("SK", AttributeValue::S(range.sort_key())) .item("state", AttributeValue::S(range.state().to_string())) .build() .expect("Failed to build PutRequest"), ) .build(); requests.push(put_request); } batch_write_until_succeeded(client, "CertTransparencyTable", requests).await; } /// Performs a batch write operation to DynamoDB, processing requests in chunks of 25 /// until all are successfully written. Retries only unprocessed items with exponential /// backoff and jitter. /// /// # Arguments /// * `client` - The DynamoDB client. /// * `table_name` - The name of the DynamoDB table. /// * `requests` - Vector of WriteRequest items to write (PutRequest or DeleteRequest). /// /// # Returns /// Returns `Ok(())` if all requests are successfully written, or an error if a non-retryable /// failure occurs. 
pub async fn batch_write_until_succeeded( client: &aws_sdk_dynamodb::Client, table_name: &str, mut requests: Vec, ) { let total_requests = requests.len(); let mut total_remaining = total_requests; debug!( "Starting batch write of {} requests to table {}", total_requests, table_name ); // Process requests in chunks of 25 (DynamoDB BatchWriteItem limit) while !requests.is_empty() { let chunk_size = requests.len().min(25); let chunk: Vec = requests.drain(..chunk_size).collect(); let mut attempt = 0; loop { let result = client .batch_write_item() .request_items(table_name, chunk.clone()) .send() .await; match result { Ok(response) => { // Check for unprocessed items if let Some(unprocessed) = response.unprocessed_items && let Some(chunk) = unprocessed.get(table_name) && !chunk.is_empty() { // Retry only unprocessed requests let sleep_duration = calculate_backoff(attempt); error!( "{} unprocessed items in batch write ({}/{}), retrying in {}sec", chunk.len(), total_remaining, total_requests, sleep_duration.as_secs() ); sleep(sleep_duration).await; attempt += 1; // retry only unprocessed requests continue; } // Success: all items in this chunk were processed total_remaining -= chunk.len(); debug!( "Successfully wrote batch of 25 ({}/{})", total_remaining, total_requests ); break; } Err(e) => { // Handle retryable errors (e.g., ProvisionedThroughputExceededException) if is_retryable_error(&e) { let sleep_duration = calculate_backoff(attempt); error!( "Failed to write batch ({}/{}), retrying in {}sec: {:?}", total_remaining, total_requests, sleep_duration.as_secs(), e ); sleep(sleep_duration).await; attempt += 1; continue; } else { // Non-retryable error, fail fast panic!("Non-retryable error during batch write: {:?}", e); } } } } } debug!("Completed batch write of all {} requests", total_requests); } /// Calculates exponential backoff duration with jitter in seconds. /// Base delay is 1s * 2^attempt, capped at 15s, with 0–3s jitter. fn calculate_backoff(attempt: u32) -> Duration { let base_delay = 2u64.pow(attempt.min(4)); // Cap at 16s, adjusted below let capped_delay = base_delay.min(15); // Ensure max 15s let jitter = rand::rng().random_range(0..=3); // Jitter up to 3s Duration::from_secs(capped_delay.max(1) + jitter) // Ensure min 1s } /// Determines if an error is retryable for BatchWriteItem (e.g., throttling or server errors). 
fn is_retryable_error( error: &SdkError, ) -> bool { match error { SdkError::ServiceError(service_err) => { matches!( service_err.err(), aws_sdk_dynamodb::operation::batch_write_item::BatchWriteItemError::ProvisionedThroughputExceededException(_) | aws_sdk_dynamodb::operation::batch_write_item::BatchWriteItemError::InternalServerError(_) | aws_sdk_dynamodb::operation::batch_write_item::BatchWriteItemError::RequestLimitExceeded(_) ) } SdkError::TimeoutError(_) => { // Timeout errors are generally retryable true } SdkError::ResponseError(_) => { // HTTP 5xx errors are generally retryable true } _ => false, } } ``` ### `services/mrpf_cert_transparency/src/controllers/mod.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/controllers/mod.rs` ```rust pub mod certclient; pub mod database; pub use certclient::{CertTransparencyClient, get_log_servers}; pub use database::{ delete_ranges_until_succeeded, get_log_server_ranges, store_fqdns_until_succeeded, update_range, update_ranges_until_succeeded, }; ``` ### `services/mrpf_cert_transparency/src/lib.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/lib.rs` ```rust pub mod controllers; pub mod models; pub mod utils; pub use controllers::{ CertTransparencyClient, delete_ranges_until_succeeded, get_log_server_ranges, get_log_servers, store_fqdns_until_succeeded, update_range, update_ranges_until_succeeded, }; pub use models::merkletree::deserialize_leaf_input; pub use models::{ LogServer, LogServerEntry, LogServerRange, LogServerRangeState, MerkleTreeLeaf, MerkleTreeLeafHeader, }; pub use utils::{aggregate_ranges, extract_cn_and_san, fqdn_cleanup, split_ranges_to_scan}; ``` ### `services/mrpf_cert_transparency/src/main.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/main.rs` ```rust // Run with cargo build && target/debug/cert_transparency use aws_config::BehaviorVersion; use aws_config::retry::RetryConfig; use aws_config::timeout::TimeoutConfig; use aws_sdk_dynamodb::Client; use base64::Engine as _; use mrpf_cert_transparency::{ CertTransparencyClient, LogServerEntry, aggregate_ranges, delete_ranges_until_succeeded, deserialize_leaf_input, extract_cn_and_san, fqdn_cleanup, get_log_server_ranges, get_log_servers, split_ranges_to_scan, update_ranges_until_succeeded, }; use mrpf_cert_transparency::{LogServer, LogServerRange, LogServerRangeState}; use std::time::Duration; use tracing::{error, info, warn}; use tracing_subscriber::fmt; use tracing_subscriber::prelude::*; #[tokio::main] async fn main() -> Result<(), Box> { panic!( "I SHOULD USE MRPF_CERT_PARSER INSTEAD OF HAVING THE CODE HERE. SHOULD JUST WORK BUT WANT TO LOOK AT THIS PROPERLY WHEN I REFACTOR THIS CODE BASE TO POSTGRES" ); let _ = setup_tracing(); // scan().await let _ = test_some_entries().await; // reset_log_servers().await?; // This will remove all log entries!! 
trying first with just https://gorgon.ct.digicert.com/log/ // let _ = calculate_all_completion().await; Ok(()) } /// Create DB Client /// /// A new session to the role will be assumed async fn create_db_client() -> Result> { // Build the config with the credentials provider let retry_config = RetryConfig::standard().with_max_attempts(6); let timeout_config = TimeoutConfig::builder() .connect_timeout(Duration::from_secs(12)) .operation_timeout(Duration::from_secs(30)) .operation_attempt_timeout(Duration::from_secs(10)) .build(); let creds_config = aws_config::defaults(BehaviorVersion::v2025_08_07()) .retry_config(retry_config.clone()) .timeout_config(timeout_config.clone()) .load() .await; // Set the role ARN to assume let role_arn = "arn:aws:iam::014534492047:role/spice-deployment-test-role"; // Create the assume role credentials provider let creds = aws_config::sts::AssumeRoleProvider::builder(role_arn) .session_name("mrpf_cert_transparency") .configure(&creds_config) .build() .await; // Create the DynamoDB client let dynamodb_config = aws_config::defaults(BehaviorVersion::v2025_08_07()) .retry_config(retry_config.clone()) .timeout_config(timeout_config.clone()) .credentials_provider(creds) .load() .await; let db_client = Client::new(&dynamodb_config); Ok(db_client) } fn setup_tracing() -> Result<(), Box> { // Suppress info log for AWS credentials // https://docs.aws.amazon.com/sdk-for-rust/latest/dg/logging.html // https://github.com/awslabs/aws-sdk-rust/discussions/771 // use std::env; // env::set_var("RUST_LOG", "aws_sdk_dynamodb=error,aws_config=error"); // Initialize tracing subscriber to print to stdout let fmt_layer = fmt::layer() // Remove filename/module path .with_target(false) // Keep the level (INFO, DEBUG, etc.) .with_level(true); // Add an EnvFilter layer to respect RUST_LOG environment variable // let filter_layer = EnvFilter::from_default_env(); // Build the subscriber with the custom layer let subscriber = tracing_subscriber::registry() .with(fmt_layer) // .with(filter_layer) .with(tracing_subscriber::filter::LevelFilter::INFO); // Set it as the global default tracing::subscriber::set_global_default(subscriber)?; Ok(()) } #[expect(dead_code)] async fn scan() -> Result<(), Box> { loop { // Create a new DynamoDB client for storing log ranges let db_client = create_db_client().await?; // Retrieve log servers and existing ranges stored in the database // let log_servers = get_log_servers(true, false).await?; let log_servers = get_log_servers(false, true).await?; // let log_servers = get_log_servers(true, true).await?; let mut log_server_ranges: Vec<(CertTransparencyClient, Vec)> = Vec::new(); let mut total_ranges_to_scan = 0; for log_server in log_servers { info!("Processing log server: {}", log_server); let mut client = CertTransparencyClient::new(log_server.clone()); let tree_size = match client.get_sth().await { Ok((_timestamp, tree_size)) => tree_size, Err(e) => { error!("Failed to retrieve tree size for {}: {}", log_server, e); continue; } }; let existing_ranges = get_log_server_ranges(&db_client, log_server.url().to_string()).await?; info!("Retrieved {} existing ranges", existing_ranges.len()); let (ranges_to_delete, ranges_to_add, ranges_to_scan) = aggregate_ranges(log_server.url(), existing_ranges, tree_size); total_ranges_to_scan += ranges_to_scan.len(); // Split into smaller chunks let split_size = 1024 * 8; let max_range_count = 1; let (ranges_to_add, ranges_to_scan) = split_ranges_to_scan(ranges_to_add, ranges_to_scan, split_size, max_range_count); if let 
Some(last_range) = ranges_to_scan.last() { let remaining_entries = tree_size - last_range.end(); info!( "Log server: {}, tree size: {}, ranges to delete: {}, ranges to add: {}, ranges to scan: {}, remaining entries after completion: {}", log_server, tree_size, ranges_to_delete.len(), ranges_to_add.len(), ranges_to_scan.len(), remaining_entries ); } // Update ranges in the database delete_ranges_until_succeeded(&db_client, &ranges_to_delete).await; update_ranges_until_succeeded(&db_client, &ranges_to_add).await; // Add ranges to scan for this log server if ranges_to_scan.is_empty() { info!("No ranges to scan for {}", log_server); continue; } else { log_server_ranges.append(&mut vec![(client, ranges_to_scan)]); } } if total_ranges_to_scan == 0 { info!( "Wow!! Have we really scanned everything?!? I'll probably never reach this point" ); break; } let start_time = chrono::Utc::now(); let mut remaining_log_servers = log_server_ranges.len(); info!( "Starting scan of {} ranges across {} log servers at {}", total_ranges_to_scan, log_server_ranges.len(), start_time.format("%H:%M:%S") ); // Asynchronously scan all log servers in parallel let mut tasks = Vec::new(); for (mut client, ranges) in log_server_ranges { let task = tokio::spawn(async move { let db_client = create_db_client().await.unwrap(); client.scan_ranges(&db_client, ranges).await; }); tasks.push(task); } // Wait for all tasks to complete for task in tasks { match task.await { Err(e) => { error!("Task failed: {:?}", e); } _ => { remaining_log_servers -= 1; info!("{} log servers remaining", remaining_log_servers); } } } let end_time = chrono::Utc::now(); info!( "Finished scan at {}, took total of {} minutes", end_time.format("%H:%M:%S"), (end_time - start_time).num_minutes() ); // Sleep for a while before starting the next scan tokio::time::sleep(Duration::from_secs(60)).await; } Ok(()) } /// Testing a few known entries to ensure our parsing is working #[expect(dead_code)] async fn test_some_entries() -> Result<(), Box> { let log_server = LogServer::new("https://gorgon.ct.digicert.com/log".to_string()); let mut client = CertTransparencyClient::new(log_server); // entry 18905 should be asnbank.nl with SAN names DNS:e-mail.asnbank.nl and DNS:asnbank.nl let entries = client.get_entries(18905, 18905).await?; for entry in entries { let leaf_input_bytes: Vec = base64::prelude::BASE64_STANDARD .decode(&entry.leaf_input) .unwrap(); let leaf = deserialize_leaf_input(leaf_input_bytes).unwrap(); match leaf.entry { LogServerEntry::X509Entry(cert) => { // println!( // "cert: {:?}", // cert.iter() // .map(|b| format!("{:02x}", b)) // .collect::>() // .join(" ") // ); let cert = cert[3..].to_vec(); // I don't know what the first 3 bytes are? 
00 0c 1d let (cn, san) = extract_cn_and_san(cert).unwrap(); if let Some(mut cn) = cn { fqdn_cleanup(&mut cn); println!("CN: {}", cn); } for mut name in san { fqdn_cleanup(&mut name); println!("SAN: {}", name); } } LogServerEntry::PrecertEntry(_cert) => { println!("Precert entry"); } } } Ok(()) } /// Calculate the completion percentage of each log server #[expect(dead_code)] async fn calculate_all_completion() -> Result<(), Box> { // Create a new DynamoDB client for storing log ranges let db_client = create_db_client().await?; // Retrieve log servers and existing ranges stored in the database let active_log_servers = get_log_servers(true, false).await?; println!( "\n\nQuerying {} active log servers:", active_log_servers.len() ); for log_server in active_log_servers { calculate_completion(&db_client, log_server).await; } let inactive_log_servers = get_log_servers(false, true).await?; println!( "\n\nQuerying {} inactive log servers:", inactive_log_servers.len() ); for log_server in inactive_log_servers { calculate_completion(&db_client, log_server).await; } Ok(()) } /// Reset all log servers in database. Note, this will remove ALL LOG ENTRIES /// which means we'll have to wait months for the logs to be re-scanned #[expect(dead_code)] async fn reset_log_servers() -> Result<(), Box> { info!("Resetting all log servers in database"); // Create a new DynamoDB client for removing log ranges let db_client = create_db_client().await?; let log_servers = get_log_servers(true, true).await?; for log_server in log_servers.iter() { info!("log server: {}", log_server); if log_server.url() == "https://gorgon.ct.digicert.com/log/" { let ranges: std::collections::HashSet<_> = get_log_server_ranges(&db_client, log_server.url().to_string()) .await? .into_iter() .collect(); info!("Deleting {} ranges for {}", ranges.len(), log_server); delete_ranges_until_succeeded(&db_client, &ranges).await; } } Ok(()) } #[expect(dead_code)] async fn calculate_completion(db_client: &Client, log_server: LogServer) { let mut client = CertTransparencyClient::new(log_server); let tree_size = match client.get_sth().await { Ok((_timestamp, tree_size)) => tree_size, Err(e) => { error!("Failed to retrieve tree size for {}: {}", client.url(), e); return; } }; let existing_ranges = get_log_server_ranges(db_client, client.url()) .await .expect("Log Server ranges should be retrievable"); let mut completed_entries = 0; for range in existing_ranges { if range.state() == LogServerRangeState::Done { completed_entries += range.end() - range.start() + 1; } } let completion_percentage = (completed_entries as f64) / (tree_size as f64) * 100.0; let count_message = format!("{}/{}", completed_entries, tree_size); let message = format!( "{:>5.2}% - {:>19} {}", completion_percentage, count_message, client.url(), ); if completion_percentage >= 99.0 { info!(message); } else if completion_percentage > 10.0 { warn!(message); } else { error!(message); } } ``` ### `services/mrpf_cert_transparency/src/models/logserver.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/models/logserver.rs` ```rust use super::retryconfig::{RetryConfig, RetryError}; // Taken from [here](https://github.com/search?q=repo%3Agoogle%2Fcertificate-transparency-community-site%20mirrors&type=code) const GOOGLE_LOG_SERVER_MIRRORS: &[(&str, &str)] = &[ ( "https://ct.googleapis.com/logs/eu1/mirrors/letsencrypt_oak2023/", "https://oak.ct.letsencrypt.org/2023/", ), ( "https://ct.googleapis.com/logs/eu1/mirrors/letsencrypt_oak2022/", "https://oak.ct.letsencrypt.org/2022/", 
), ( "https://ct.googleapis.com/logs/eu1/mirrors/digicert_yeti2022_2/", "https://yeti2022-2.ct.digicert.com/log/", ), ( "https://ct.googleapis.com/logs/eu1/mirrors/digicert_yeti2023/", "https://yeti2023.ct.digicert.com/log/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/digicert_nessie2022/", "https://nessie2022.ct.digicert.com/log/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/digicert_nessie2023/", "https://nessie2023.ct.digicert.com/log/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/cloudflare_nimbus2022/", "https://ct.cloudflare.com/logs/nimbus2022/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/cloudflare_nimbus2023/", "https://ct.cloudflare.com/logs/nimbus2023/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/cloudflare_nimbus2024/", "https://ct.cloudflare.com/logs/nimbus2024/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/cloudflare_nimbus2025/", "https://ct.cloudflare.com/logs/nimbus2025/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/comodo_mammoth/", "https://mammoth.ct.comodo.com/", ), ( "https://ct.googleapis.com/logs/us1/mirrors/comodo_sabre/", "https://sabre.ct.comodo.com/", ), ]; #[derive(Debug, Clone)] pub struct LogServer { url: String, mirror: Option, retry_config: RetryConfig, } impl LogServer { /// Initialize a new LogServer. /// /// We try to find a mirror log server is present. pub fn new(url: String) -> Self { for (mirror_url, original_url) in GOOGLE_LOG_SERVER_MIRRORS { if &url == original_url { return LogServer { url: original_url.to_string(), mirror: Some(mirror_url.to_string()), retry_config: RetryConfig::default(), }; } } LogServer { url, mirror: None, retry_config: RetryConfig::default(), } } /// Retrieve the URL of the log server pub fn url(&self) -> &str { &self.url } /// Retrieve the mirror URL of the log server pub fn mirror(&self) -> Option<&str> { self.mirror.as_deref() } /// Calculate backoff time for retrying pub fn retry(&mut self) -> RetryError { self.retry_config.rate_limiter(None) } } impl PartialEq for LogServer { fn eq(&self, other: &Self) -> bool { self.url == other.url } } impl Eq for LogServer {} impl std::hash::Hash for LogServer { fn hash(&self, state: &mut H) { self.url.hash(state); } } impl std::fmt::Display for LogServer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(mirror) = &self.mirror { write!(f, "{}", mirror) } else { write!(f, "{}", self.url) } } } ``` ### `services/mrpf_cert_transparency/src/models/logserverentry.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/models/logserverentry.rs` ```rust /// Defines the type of log entry as per RFC 6962, Section 3.1. /// /// ```text /// enum { x509_entry(0), precert_entry(1), (65535) } LogServerEntryType; /// ``` /// /// See RFC 6962, Section 3.1: https://tools.ietf.org/html/rfc6962#section-3.1 #[derive(Debug, PartialEq)] pub enum LogServerEntryType { X509LogServerEntryType = 0, PrecertLogServerEntryType = 1, } /// Represents the entry field in a MerkleTreeLeaf, varying by log_entry_type. #[derive(Debug)] pub enum LogServerEntry { X509Entry(Vec), // ASN.1 encoded certificate PrecertEntry(Vec), // ASN.1 encoded pre-certificate // PrecertEntry(PrecertEntry), // PrecertEntry containing issuer key hash and server certificate. TODO: Implement the parser } /// Represents a PrecertEntry, containing the issuer key hash and server certificate. 
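// Illustrative sketch for `LogServer::new` in logserver.rs above (added annotation, not part of
// the original source): when a Google mirror is known for a log URL, it is attached and the
// `Display` impl prefers it:
//
//     let ls = LogServer::new("https://mammoth.ct.comodo.com/".to_string());
//     assert_eq!(ls.url(), "https://mammoth.ct.comodo.com/");
//     assert_eq!(ls.mirror(), Some("https://ct.googleapis.com/logs/us1/mirrors/comodo_mammoth/"));
//     assert_eq!(ls.to_string(), "https://ct.googleapis.com/logs/us1/mirrors/comodo_mammoth/");
//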
#[derive(Debug, PartialEq)] pub struct PrecertEntry { pub leaf_cert: [u8; 32], // 32-byte issuer key hash pub server_cert: Vec, // DER-encoded server certificate } ``` ### `services/mrpf_cert_transparency/src/models/logserverrange.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/models/logserverrange.rs` ```rust use aws_sdk_dynamodb::types::AttributeValue; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use tracing::debug; // Ensure a fixed length for the range string in DynamoDB const RANGE_ZERO_PADDING: usize = 12; /// The state of a LogServerRange #[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)] pub enum LogServerRangeState { Done, Pending, } impl std::fmt::Display for LogServerRangeState { /// Serialize the LogServerRangeState to a string fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let s = match self { LogServerRangeState::Done => "done", LogServerRangeState::Pending => "pending", }; write!(f, "{}", s) } } /// A Single range for a log server /// /// The ranges are used to determine what entries of the tree have been processed or /// are currently being processed. #[derive(Debug, Clone, PartialEq, Eq)] pub struct LogServerRange { /// The log server log_server: String, /// The state of the range state: LogServerRangeState, /// The start index of the range start: i64, /// The end index of the range end: i64, /// Current index of the range during scanning current: i64, } impl std::fmt::Display for LogServerRange { /// Serialize the LogServerRange to a string fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{} {}-{} {}", self.log_server, self.start, self.end, self.state ) } } // Implement Hash for LogServerRange impl Hash for LogServerRange { fn hash(&self, state: &mut H) { self.state.hash(state); // Hash the state field self.start.hash(state); // Hash the start field self.end.hash(state); // Hash the end field } } /// Deserialize the LogServerRange struct from a HashMap for DynamoDB. 
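// Illustrative sketch (added annotation, not part of the original source): a stored item such as
//   PK = "LOGSERVER#https://example-log/", SK = "RANGE#000000000000-000000001000", state = "done"
// deserializes via the From impl below into start = 0, end = 1000, state = Done, current = end:
//
//     let mut item = HashMap::new();
//     item.insert("PK".to_string(), AttributeValue::S("LOGSERVER#https://example-log/".into()));
//     item.insert("SK".to_string(), AttributeValue::S("RANGE#000000000000-000000001000".into()));
//     item.insert("state".to_string(), AttributeValue::S("done".into()));
//     let range = LogServerRange::from(&item);
//     assert_eq!((range.start(), range.end(), range.is_done()), (0, 1000, true));
//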
impl From<&HashMap> for LogServerRange { /// The PK is formatted as LOGSERVER#{log_server} /// The sort key is formatted as RANGE#{start}-{end} fn from(map: &HashMap) -> Self { let log_server = map .get("PK") .and_then(|v| v.as_s().ok()) .unwrap() .split("#") .last() .unwrap() .to_string(); let sk = map.get("SK").and_then(|v| v.as_s().ok()).unwrap(); let parts: Vec<&str> = sk.split('#').collect(); if parts.len() != 2 { panic!("Invalid sort key format: {}", sk); } let range: Vec<&str> = parts[1].split('-').collect(); if range.len() != 2 { panic!("Invalid range format"); } let start = range[0].parse().unwrap(); let end = range[1].parse().unwrap(); let state = match map .get("state") .and_then(|v| v.as_s().ok()) .unwrap() .as_str() { "done" => LogServerRangeState::Done, "pending" => LogServerRangeState::Pending, _ => panic!("Invalid state"), }; let current = match state { LogServerRangeState::Done => end, LogServerRangeState::Pending => start, }; Self { log_server, state, start, end, current, } } } impl LogServerRange { /// Initialize a new LogServerRange /// /// # Arguments /// /// * `log_server` - The log server /// * `state` - The state of the range /// * `start` - The start index of the range /// * `end` - The end index of the range pub fn new(log_server: String, state: LogServerRangeState, start: i64, end: i64) -> Self { let (start, end) = if start < end { (start, end) } else { (end, start) }; let current = match state { LogServerRangeState::Done => end, LogServerRangeState::Pending => start, }; Self { log_server, state, start, end, current, } } /// Get log server pub fn log_server(&self) -> &str { &self.log_server } /// Get the start index of the range pub fn start(&self) -> i64 { self.start } /// Get the current index of the range pub fn current(&self) -> i64 { self.current } // Get the end index of the range pub fn end(&self) -> i64 { self.end } /// Update the current index of the range with the new retrieved entries. /// /// The num_entries is added to the current index to update the current index. /// When we've reached the end of the range, the state is set to Done. pub fn set_current(&mut self, num_entries: i64) { self.current += num_entries; if self.current < self.start { panic!( "Invalid current index {} for range {}-{}", self.current, self.start, self.end ); } else if self.current >= self.end { self.current = self.end; self.state = LogServerRangeState::Done; } } /// Update the end index of the range pub fn set_end(&mut self, end: i64) { self.end = end; } /// Check if the range is completed. pub fn is_done(&self) -> bool { self.state == LogServerRangeState::Done } /// Combine another LogServerRange into a new one if they overlap or are adjacent /// /// If the ranges state overlap but the state is incorrect we panic as this is an invalid state. 
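// Illustrative sketch (added annotation, not part of the original source): adjacent ranges with
// the same state merge into one, while overlapping ranges with different states are left alone
// (the implementation below returns None for that case):
//
//     let url = "https://example-log/".to_string();
//     let a = LogServerRange::new(url.clone(), LogServerRangeState::Done, 0, 100);
//     let b = LogServerRange::new(url.clone(), LogServerRangeState::Done, 101, 200);
//     assert_eq!(
//         a.combine(&b),
//         Some(LogServerRange::new(url.clone(), LogServerRangeState::Done, 0, 200))
//     );
//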
/// /// # Arguments /// /// * `other` - The other LogServerRange to combine pub fn combine(&self, other: &LogServerRange) -> Option { if self.start == other.start && self.end == other.end { // The ranges are equal if self.state == LogServerRangeState::Done || other.state == LogServerRangeState::Done { // If one of the ranges is done, we will set the combined range to done Some(LogServerRange::new( self.log_server.clone(), LogServerRangeState::Done, self.start, self.end, )) } else { Some(LogServerRange::new( self.log_server.clone(), LogServerRangeState::Pending, self.start, self.end, )) } } else if self.start <= other.end + 1 && other.start <= self.end + 1 { // The ranges overlap or are adjacent if self.state != other.state { // if self.start <= other.end && other.start <= self.end { // warn!( // "Ranges {}-{} {} and {}-{} {} overlap but have different states, defaulting to pending", // self.start, self.end, self.state, // other.start, other.end, other.state // ); // return Some(LogServerRange::new( // self.log_server.clone(), // LogServerRangeState::Pending, // start, // end, // )); // } else { // // Ranges that are just adjacent are correct and should not be merged when in a different state // return None; // } // We've got the ranges working a lot better now. This was causing some issues when I switched around // range sizes somehow. For instance had the following when running locally with a lot smaller range: // 2025-03-02T21:58:25.444475Z WARN Ranges 0-860297 done and 860265-868457 pending overlap but have different states, defaulting to pending // 2025-03-02T21:58:25.444497Z WARN Ranges 0-868457 pending and 860298-860330 done overlap but have different states, defaulting to pending // 2025-03-02T21:58:25.444509Z WARN Ranges 0-868457 pending and 860331-860363 done overlap but have different states, defaulting to pending // 2025-03-02T21:58:25.444520Z WARN Ranges 0-868457 pending and 860364-860396 done overlap but have different states, defaulting to pending // This basically reset a LOT of the progress I had made, so I'm going to remove this for now. Be careful with enabling this again, really // check why this happened. None } else { let start = self.start.min(other.start); let end = self.end.max(other.end); Some(LogServerRange::new( self.log_server.clone(), self.state, start, end, )) } } else { debug!( "Ranges {}-{} {} and {}-{} {} do not overlap", self.start, self.end, self.state, other.start, other.end, other.state ); None } } /// Split LogServerRange into smaller ranges. /// /// # Arguments /// /// * `size` - The maximum size of a single range /// * `max_split` - The maximum amount of ranges to split into pub fn split(&self, size: i64, max_split: usize) -> Vec { let mut ranges = Vec::new(); let mut start = self.start; let end = self.end; while start < end { let next_end = (start + size).min(end); ranges.push(LogServerRange::new( self.log_server.clone(), self.state, start, next_end, )); start = next_end + 1; if ranges.len() >= max_split { // We're just going to leave the rest of the range as is // as we've reached the maximum amount of ranges to scan. // We will not add it to the database or SQS. It should // pick itself up in the next run. 
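// For example (added annotation, not part of the original source):
//
//     let r = LogServerRange::new(
//         "https://example-log/".to_string(),
//         LogServerRangeState::Pending,
//         0,
//         20_000,
//     );
//     let parts = r.split(8192, 2);
//     // parts covers 0-8192 and 8193-16385; the tail 16386-20000 is left for a later run.
//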
// if start <= end { // ranges.push(LogServerRange::new(self.state, start, end)); // } break; } } ranges } /// Generate the DynamoDB primary key for the range /// The PK is formatted as LOGSERVER#{log_server} pub fn primary_key(&self) -> String { format!("LOGSERVER#{}", self.log_server) } /// Generate the DynamoDB sort key for the range /// The sort key is formatted as RANGE#{start}-{end} pub fn sort_key(&self) -> String { format!( "RANGE#{:0digits$}-{:0digits$}", self.start, self.end, digits = RANGE_ZERO_PADDING ) } pub fn state(&self) -> LogServerRangeState { self.state } } ``` ### `services/mrpf_cert_transparency/src/models/merkletree.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/models/merkletree.rs` ```rust use super::logserverentry::{LogServerEntry, LogServerEntryType, PrecertEntry}; #[allow(unused_imports)] use base64::Engine as _; /// Represents a Merkle Tree Leaf Header as defined in RFC 6962, Section 3.4. /// /// The header is used to parse and determine the LogServerEntryType. After this we /// will parse the remaining bytes as the LogServerEntry. and initialize a MerkleTreeLeaf. /// /// Ensure structs are C-compatible with the exact memory layout. TODO: Do i also need repr packed to ensure /// no padding is added? IS THIS PERHAPS WHY I HAD ISSUES WITH HAVING TO STRIP THE FIRST FEW BYTES? #[repr(C)] #[derive(Debug)] pub struct MerkleTreeLeafHeader { version: u8, // 1 byte merkle_leaf_type: u8, // 1 byte timestamp: u64, // 8 bytes, big-endian log_entry_type: u16, // 2 bytes, big-endian } /// Represents a Merkle Tree Leaf as defined in RFC 6962, Section 3.4. /// /// ```text /// struct { /// Version version; // 1 byte /// MerkleLeafType leaf_type; // 1 byte /// uint64 timestamp; // 8 bytes /// LogServerEntryType entry_type; // 2 bytes /// select (entry_type) { /// case x509_entry: X509ChainEntry; /// case precert_entry: PrecertChainEntry; /// } entry; /// } MerkleTreeLeaf; /// ``` /// /// See RFC 6962, Section 3.4: https://tools.ietf.org/html/rfc6962#section-3.4 #[derive(Debug)] pub struct MerkleTreeLeaf { pub version: u8, pub merkle_leaf_type: u8, pub timestamp: u64, pub log_entry_type: LogServerEntryType, pub entry: LogServerEntry, } /// Parses a PrecertEntry from the entry bytes. /// Expects a SEQUENCE { OCTET STRING (32 bytes), OCTET STRING (server cert) }. /// TODO: NOT WORKING?! 
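// Illustrative sketch (added annotation, not part of the original source): the byte layout that
// `deserialize_leaf_input` further below expects, using a hypothetical DER-encoded certificate:
//
//     let mut leaf_input = Vec::new();
//     leaf_input.push(0x00);                                              // version (v1)
//     leaf_input.push(0x00);                                              // merkle_leaf_type (timestamped_entry)
//     leaf_input.extend_from_slice(&1_740_000_000_000u64.to_be_bytes());  // timestamp, big-endian
//     leaf_input.extend_from_slice(&0u16.to_be_bytes());                  // log_entry_type: 0 = x509_entry
//     leaf_input.extend_from_slice(&der_cert_bytes);                      // entry bytes (hypothetical cert)
//     let leaf = deserialize_leaf_input(leaf_input)?;                     // -> MerkleTreeLeaf, X509Entry variant
//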
#[expect(dead_code)] fn parse_precert_entry(entry_bytes: Vec) -> Result { const LEAF_CERT_SIZE: usize = 32; const PREFIX_SIZE: usize = 3; // Known prefix: 00 03 1a const SUFFIX_SIZE: usize = 2; // Known suffix: 00 00 if entry_bytes.len() < LEAF_CERT_SIZE + PREFIX_SIZE + 1 { return Err("Insufficient bytes for PrecertEntry".to_string()); } // Extract leaf_cert (first 32 bytes) let leaf_cert: [u8; 32] = entry_bytes[..LEAF_CERT_SIZE] .try_into() .map_err(|_| "Failed to copy leaf_cert".to_string())?; // Skip the prefix (00 03 1a) and check for SEQUENCE let cert_start = LEAF_CERT_SIZE + PREFIX_SIZE; if cert_start >= entry_bytes.len() || entry_bytes[cert_start] != 0x30 { return Err("Invalid server_cert: does not start with SEQUENCE".to_string()); } // Extract server_cert, excluding the suffix (00 00) let cert_end = entry_bytes .len() .checked_sub(SUFFIX_SIZE) .ok_or("Insufficient bytes for server_cert")?; if cert_end <= cert_start { return Err("Invalid server_cert: empty after prefix/suffix".to_string()); } let server_cert = entry_bytes[cert_start..cert_end].to_vec(); println!( "server_cert: {:?}", server_cert .iter() .map(|b| format!("{:02x}", b)) .collect::>() .join(" ") ); use crate::extract_cn_and_san; let (cn, san) = extract_cn_and_san(server_cert.to_vec()).unwrap(); println!("CN: {:?}", cn); println!("SAN: {:?}", san); // Validate server_cert length (from SEQUENCE header) if server_cert.len() < 4 { return Err("Invalid server_cert: too short for SEQUENCE".to_string()); } let seq_len = if server_cert[1] & 0x80 == 0 { server_cert[1] as usize } else { let len_bytes = (server_cert[1] & 0x7F) as usize; if len_bytes > 2 || server_cert.len() < 2 + len_bytes { return Err("Invalid server_cert: malformed length".to_string()); } let mut len = 0; for &b in &server_cert[2..2 + len_bytes] { len = (len << 8) | b as usize; } len }; if seq_len + 2 + (server_cert[1] & 0x7F) as usize != server_cert.len() { return Err("Invalid server_cert: length mismatch".to_string()); } Ok(PrecertEntry { leaf_cert, server_cert, }) } /// Deserializes a `leaf_input` from a raw Vec into a MerkleTreeLeaf. /// /// This function assumes the input is already base64-decoded from the JSON response. /// It parses the binary structure according to RFC 6962, Section 3.4. /// Deserializes a `leaf_input` from a raw Vec into a MerkleTreeLeaf. /// /// This function assumes the input is already base64-decoded from the JSON response. 
/// It parses the binary structure according to RFC 6962, Section 3.4: /// - version: 1 byte /// - merkle_leaf_type: 1 byte /// - timestamp: 8 bytes (big-endian) /// - log_entry_type: 2 bytes (big-endian) /// - entry: remaining bytes pub fn deserialize_leaf_input(leaf_input: Vec) -> Result { const HEADER_SIZE: usize = 1 + 1 + 8 + 2; // 12 bytes if leaf_input.len() < HEADER_SIZE { return Err("Insufficient bytes for MerkleTreeLeaf header".to_string()); } let version = leaf_input[0]; let merkle_leaf_type = leaf_input[1]; let timestamp = u64::from_be_bytes( leaf_input[2..10] .try_into() .map_err(|_| "Failed to parse timestamp".to_string())?, ); let log_entry_type_raw = u16::from_be_bytes( leaf_input[10..12] .try_into() .map_err(|_| "Failed to parse log_entry_type".to_string())?, ); let log_entry_type = match log_entry_type_raw { 0 => LogServerEntryType::X509LogServerEntryType, 1 => LogServerEntryType::PrecertLogServerEntryType, _ => { return Err(format!( "Invalid LogServerEntryType: {}", log_entry_type_raw )); } }; let entry_bytes = leaf_input[HEADER_SIZE..].to_vec(); let entry = match log_entry_type { LogServerEntryType::X509LogServerEntryType => LogServerEntry::X509Entry(entry_bytes), LogServerEntryType::PrecertLogServerEntryType => LogServerEntry::PrecertEntry(entry_bytes), // LogServerEntryType::PrecertLogServerEntryType => { // let precert = parse_precert_entry(entry_bytes)?; // LogServerEntry::PrecertEntry(precert) // } }; Ok(MerkleTreeLeaf { version, merkle_leaf_type, timestamp, log_entry_type, entry, }) } #[cfg(test)] mod tests { use super::*; use crate::extract_cn_and_san; const VALID_CERT_LEAF_INPUT: &str = "AAAAAAGVVfJpPQAAAAaUMIIGkDCCBXigAwIBAgISA2XT8AYZc60H2sbYwyidAjLpMA0GCSqGSIb3DQEBCwUAMDMxCzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MQwwCgYDVQQDEwNSMTEwHhcNMjUwMzAyMDcyMzE0WhcNMjUwNTMxMDcyMzEzWjAoMSYwJAYDVQQDEx13ZWJtYWlsLnBuZC5keGMubXlibHVlaG9zdC5tZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMDAG+eH3Fb4kwEhzpPQAob/sjlkjgwKyXMpXZUOo8h5ELfVFnOU1eS8DmB3S3Dh3Al6lXR13gCaaddd/KPjY2AFpwZsvSzQe+CRswbzQB7Y+HAxRlYv2zLo9bmb7U5Itm23vOz0Cv3pJ8WujccAl8qgU7pPDRXJAi3sKqDU0DR1D9uai4a5uz4nvcTZMbQ2X1BNbvvWN3BqCsVnMuw7EtiR4AQV7Ri8Wj5g3VJt9iPGiWbSXqqYGFhfkT8yJxJgiFI+SOdf7HiF9N3bvTvsv4yxYbrSSlhAvydkwqg2CBC4pxbJcjDGTGQRVHMTZUwJuzRmQ6o/Y1qSL+a0m+QIcSsCAwEAAaOCA6cwggOjMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUq4Ed/lJGnoQurntQjCmkE63t2gIwHwYDVR0jBBgwFoAUxc9GpOr0w8B6bJXELbBeki8m47kwVwYIKwYBBQUHAQEESzBJMCIGCCsGAQUFBzABhhZodHRwOi8vcjExLm8ubGVuY3Iub3JnMCMGCCsGAQUFBzAChhdodHRwOi8vcjExLmkubGVuY3Iub3JnLzCCAa0GA1UdEQSCAaQwggGggiJhdXRvZGlzY292ZXIucG5kLmR4Yy5teWJsdWVob3N0Lm1lghxjcGFuZWwucG5kLmR4Yy5teWJsdWVob3N0Lm1lgiFjcGNhbGVuZGFycy5wbmQuZHhjLm15Ymx1ZWhvc3QubWWCIGNwY29udGFjdHMucG5kLmR4Yy5teWJsdWVob3N0Lm1lghFoYXJyaXNiaWxsaW5nLmNvbYIWbWFpbC5oYXJyaXNiaWxsaW5nLmNvbYIabWFpbC5wbmQuZHhjLm15Ymx1ZWhvc3QubWWCGW1haWwucmVuZWVvZmFsbHRyYWRlcy5jb22CFXBuZC5keGMubXlibHVlaG9zdC5tZYIUcmVuZWVvZmFsbHRyYWRlcy5jb22CHXdlYmRpc2sucG5kLmR4Yy5teWJsdWVob3N0Lm1lgh13ZWJtYWlsLnBuZC5keGMubXlibHVlaG9zdC5tZYIVd3d3LmhhcnJpc2JpbGxpbmcuY29tghl3d3cucG5kLmR4Yy5teWJsdWVob3N0Lm1lghh3d3cucmVuZWVvZmFsbHRyYWRlcy5jb20wEwYDVR0gBAwwCjAIBgZngQwBAgEwggEDBgorBgEEAdZ5AgQCBIH0BIHxAO8AdQCi4wrkRe+9rZt+OO1HZ3dT14JbhJTXK14bLMS5UKRH5wAAAZVV8me1AAAEAwBGMEQCIEnw+E8t1SPW8zTslnvIJV1Kkl8bwvN6qcbKwx5g6keKAiAU/ADYZLB64FnEfeG6IyNJqBAplpGN3LblMzsNKzVuOQB2ABNK3xq1mEIJeAxv70x6kaQWtyNJzlhXat+u2qfCq+AiAAABlVXyaJIAAAQDAEcwRQIgVyqOdvH5E/xAAPfpoDesmevki3A/Z82TqBDvWEDANjQCIQDfWAeVZY6iNZbHo+2IzeAKBfSSt8ViTRGUgJNtbJbISDANBgkqhkiG9w0BA
QsFAAOCAQEAjUCTM8iqaBineUSZRX9JFmRSxRgKv52EQydZXYLq8yW3mZRrF6nUcrIj7zdLVRa4IkFjtmFx+p8M62Ah5jBdAvyA7YhDUmICzhwAl999eEwiMmXp8CyYapqmc7c9QI+LEHay/hWmaEYCdDdaA0NrzuQazC5C+Kls3+Q9fczSHeemy7rM3ns2elByNmeUc5YNhbklKBwpV/ZRz97QphTiG2h3QlTBb5iLwUMULOMP967zVbA2C7QYCj0LDwj1zCOWQV0Wt6KW2SHDmT0XNkMa1k+rm0u+86vw5XXlnfMjJvFD7oWYm/N3/6hXj6VTSDPPQYkLEW7rna20o/O7kaja4QAA"; // const VALID_PRECERT_EXTRA_DATA: &str = "AAp/AAUKMIIFBjCCAu6gAwIBAgIRAIp9PhPWLzDvI4a9KQdrNPgwDQYJKoZIhvcNAQELBQAwTzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2VhcmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMjQwMzEzMDAwMDAwWhcNMjcwMzEyMjM1OTU5WjAzMQswCQYDVQQGEwJVUzEWMBQGA1UEChMNTGV0J3MgRW5jcnlwdDEMMAoGA1UEAxMDUjExMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuoe8XBsAOcvKCs3UZxD5ATylTqVhyybKUvsVAbe5KPUoHu0nsyQYOWcJDAjs4DqwO3cOvfPlOVRBDE6uQdaZdN5R2+97/1i9qLcT9t4x1fJyyXJqC4N0lZxGAGQUmfOx2SLZzaiSqhwmej/+71gFewiVgdtxD4774zEJuwm+UE1fj5F2PVqdnoPy6cRms+EGZkNIGIBloDcYmpuEMpexsr3E+BUAnSeI++JjF5ZsmydnS8TbKF5pwnnwSVzgJFDhxLyhBax7QG0AtMJBP6dYuC/FXJuluwme8f7rsIU5/agK70XEeOtlKsLPXzze41xNG/cLJyuqC0J3U095ah2H2QIDAQABo4H4MIH1MA4GA1UdDwEB/wQEAwIBhjAdBgNVHSUEFjAUBggrBgEFBQcDAgYIKwYBBQUHAwEwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUxc9GpOr0w8B6bJXELbBeki8m47kwHwYDVR0jBBgwFoAUebRZ5nu25eQBc4AIiMgaWPbpm24wMgYIKwYBBQUHAQEEJjAkMCIGCCsGAQUFBzAChhZodHRwOi8veDEuaS5sZW5jci5vcmcvMBMGA1UdIAQMMAowCAYGZ4EMAQIBMCcGA1UdHwQgMB4wHKAaoBiGFmh0dHA6Ly94MS5jLmxlbmNyLm9yZy8wDQYJKoZIhvcNAQELBQADggIBAE7iiV0KAxyQOND1H/lxXPjDj7I3iHpvsCUf7b632IYGjukJhM1yv4Hz/MrPU0jtvfZpQtSlET41yBOykh0FX+ou1Nj4ScOt9ZmWnO8m2OG0JAtIIE3801S0qcYhyOE2G/93ZCkXufBL713qzXnQv5C/viOykNpKqUgxdKlEC+Hi9i2DcaR1e9KUwQUZRhy5j/PEdEglKg3l9dtD4tuTm7kZtB8v32oOjzHTYw+7KdzdZiw/sBtnUfhBPORNuay4pJxmY/WrhSMdzFO2q3Gu3MUBcdo27goYKjL9CTF8j/Zz55yctUoVaneCWs/ajUX+HypkBTA+c8LGDLnWO2NKq0YD/pnARkAnYGPfUDoHR9gVSp/qRx+ZWghiDLZsMwhN1zjtSC0uBWiugF3vTNzYIEFfaPG7Ws3jDrAMMYebQ95JQ+HIBD/RPBuHRTBpqKlyDnkSHDHYPiNX3adPoPAcgdF3H2/W0rmoswMWgTlLn1Wu0mrks7/qpdWfS6PJ1jty80r2VKsM/Dj3YIDfbjXKdaFU5C+8bhfJGqU3taKauuz0wHVGT3eo6FlWkWYtbt4pgdamlwVeZEW+LM7qZEJEsMNPrfC03APKmZsJgpWCDWOKZvkZcvjVuYkQ4omYCTX5ohy+knMjdOmdH9c7SpqEWBDC86fiNex+O0XOMEZSa8DAAAVvMIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAwTzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2VhcmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJuZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBYMTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygch77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6UA5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sWT8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyHB5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UCB5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUvKBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWnOlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTnjh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbwqHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CIrU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkqhkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZLubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KKNFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5ORAzI4JMPJ+Gsl
WYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7UrTkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdCjNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVcoyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPAmRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57demyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=; const _VALID_PRECERT_LEAF_INPUT: &str = "AAAAAAGKggjsawAB43aJADBzoMZJzGVt6UbAMXTSXFZv48OAW4RvUjaUN5gAAxowggMWoAMCAQICCESrJbcV2CxkMA0GCSqGSIb3DQEBCwUAMH8xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xFzAVBgNVBAoMDkdvb2dsZSBVSyBMdGQuMSEwHwYDVQQLDBhDZXJ0aWZpY2F0ZSBUcmFuc3BhcmVuY3kxIzAhBgNVBAMMGk1lcmdlIERlbGF5IEludGVybWVkaWF0ZSAxMB4XDTI1MDUzMDEwMjU0OFoXDTI1MDUzMTEwMjU0OFowdTELMAkGA1UEBhMCR0IxDzANBgNVBAcTBkxvbmRvbjEPMA0GA1UEChMGR29vZ2xlMSEwHwYDVQQLExhDZXJ0aWZpY2F0ZSBUcmFuc3BhcmVuY3kxITAfBgNVBAMTGGZsb3dlcnMtdG8tdGhlLXdvcmxkLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALSlbBWEhSrKZdynl37nbREG2XGUH3Wnhq/LhGaJMuF8aeuQn/6cu5+eTI6KzQSZtSNQsYaSs6hqd/jwD66V0zLSFla64NjHfSWE5UmiUNlZB6ZBTk78COva9SNbttIVEkA3JBEj96/zxdtOc1xVKGItZG6yWyPMSp3mn/7y2A5LNtp20LPBhlYgHx5P7netcOXdR7d9Sj26hVRdzGIGAVnqMIOXmY9x9c6Xyka6VJcWZJBuNqY5U3pJL08StCn68bp6XKRNNPv/cIVW10ja3RPVcHLzXoqLLvY1csk9+3TV/hFj3CiQ8hsBbU0vzuPhG85WGrbd6HgMi8gjJmoLLG8CAwEAAaOBtzCBtDAOBgNVHQ8BAf8EBAMCB4AwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAfBgNVHSMEGDAWgBTpPAThgC/ChBMtJnCe8v0az6r+xjBeBgNVHREEVzBVghhmbG93ZXJzLXRvLXRoZS13b3JsZC5jb22COTAyLjExLjA5LjIwMjMuZ29vZ2xlX2FyZ29uMjAyNWgxLmZsb3dlcnMtdG8tdGhlLXdvcmxkLmNvbQAA"; // const VALID_PRECERT_EXTRA_DATA: &str = "AAVHMIIFQzCCAyugAwIBAgIIRKsltxXYLGQwDQYJKoZIhvcNAQELBQAwfzELMAkGA1UEBhMCR0IxDzANBgNVBAgMBkxvbmRvbjEXMBUGA1UECgwOR29vZ2xlIFVLIEx0ZC4xITAfBgNVBAsMGENlcnRpZmljYXRlIFRyYW5zcGFyZW5jeTEjMCEGA1UEAwwaTWVyZ2UgRGVsYXkgSW50ZXJtZWRpYXRlIDEwHhcNMjUwNTMwMTAyNTQ4WhcNMjUwNTMxMTAyNTQ4WjB1MQswCQYDVQQGEwJHQjEPMA0GA1UEBxMGTG9uZG9uMQ8wDQYDVQQKEwZHb29nbGUxITAfBgNVBAsTGENlcnRpZmljYXRlIFRyYW5zcGFyZW5jeTEhMB8GA1UEAxMYZmxvd2Vycy10by10aGUtd29ybGQuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtKVsFYSFKspl3KeXfudtEQbZcZQfdaeGr8uEZoky4Xxp65Cf/py7n55MjorNBJm1I1CxhpKzqGp3+PAPrpXTMtIWVrrg2Md9JYTlSaJQ2VkHpkFOTvwI69r1I1u20hUSQDckESP3r/PF205zXFUoYi1kbrJbI8xKneaf/vLYDks22nbQs8GGViAfHk/ud61w5d1Ht31KPbqFVF3MYgYBWeowg5eZj3H1zpfKRrpUlxZkkG42pjlTekkvTxK0KfrxunpcpE00+/9whVbXSNrdE9VwcvNeiosu9jVyyT37dNX+EWPcKJDyGwFtTS/O4+EbzlYatt3oeAyLyCMmagssbwIDAQABo4HMMIHJMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFOk8BOGAL8KEEy0mcJ7y/RrPqv7GMF4GA1UdEQRXMFWCGGZsb3dlcnMtdG8tdGhlLXdvcmxkLmNvbYI5MDIuMTEuMDkuMjAyMy5nb29nbGVfYXJnb24yMDI1aDEuZmxvd2Vycy10by10aGUtd29ybGQuY29tMBMGCisGAQQB1nkCBAMBAf8EAgUAMA0GCSqGSIb3DQEBCwUAA4ICAQAIc8kagAzVUtsH30SvKLwsN+2tIzpphz2+s0/B5YJ6qwqTSp5agq9i4X2Z4T3BqTKvneB2/oG8XCYspFTMNRpHiXcJOJrI/2kWfMNH6+vIYUXyvfOXcPHgVJX1um9jMEgYyex9zcP1AWeHDSXHq34J+Sba8q2N2ze8ip/HEtakqoIjl6skqFTr+CVtyzVjN99Zb++NOVW0esdOjDCE5XItaJmzDWsGQs1BKjSNSeO6IVmBlg3SIaBTKXprb84hTFRh0AomrlJUd+t5qQBN5c+Vvx3dKNUqjonLkNSMOO2VELicqKV1UjCzFJfFauOmahdTiPSyROymKnkg3/eqzYKxD8HhOqcd8jNmPos7e2GrnxQb63FLZUta9ANYy2js54eU7A6VpcXeRsO3D0h31iHvnXMp/orjdFXh2kRxaphMBkdyXcuBihymAVXvDpsHZ607oqJeW75LXmHTpHYiNRTzz/Htz+/DqqcYsJGvWPT80WJDfQCb2aRB9BKrcLTzQmP/rjoIYYqvy+lqPks3E8/mDyTA69NHGIc/unczPjzs2uieVpjtXH+R+WTwshRpX+kCJUi1zJHG77I9GmkMmeDis9rpHqkBBtgTbOoB2jRpLBittQeJk8FtxZFptkeLrhQNFZEpcjRMupd2PRvGB2MBmyNrVvhtfhjzpEJCxcCshwALowAFzDCCBcgwggOwoAMCAQICAhABMA0GCSqGSIb3DQEBBQUAMH0xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xFzAVBgNV
BAoMDkdvb2dsZSBVSyBMdGQuMSEwHwYDVQQLDBhDZXJ0aWZpY2F0ZSBUcmFuc3BhcmVuY3kxITAfBgNVBAMMGE1lcmdlIERlbGF5IE1vbml0b3IgUm9vdDAeFw0xNDA3MTcxMjI2MzBaFw0xOTA3MTYxMjI2MzBaMH8xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xFzAVBgNVBAoMDkdvb2dsZSBVSyBMdGQuMSEwHwYDVQQLDBhDZXJ0aWZpY2F0ZSBUcmFuc3BhcmVuY3kxIzAhBgNVBAMMGk1lcmdlIERlbGF5IEludGVybWVkaWF0ZSAxMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAweh0/v+a7vMDu/pjRTiB+q+NwcIsCWQdr0MDgfM7wVe/bEyKjVexq8eShZ0g8hkVCcWXxDexRnPepa9L6hQ5bdQ23GIFVdeVPg7eAff/tE8/983mTtJFY04N8Kr86cOsXrY9jeHZacrIhUGVQDqfnR1MPc7t8TUe3ZRXQ7xUq3RbIE+1JZ/j7faVws+QuIbEj/aAt0T+wmkbQ0UkKzHDbzEYtyfF3lwl7BqjCkwkYcURnva7kNgW5uRMW5lVv6LtNBa/bkpTyS+vrA0fC4vz01vg1PYcoF0o/GYt+liPuj7gOARwwBLe2eUbvx56Je+nRXhMSdBer43O4FJzYeyRMSYAXpcs9LhjkU+DYVgu0kVj/54DxS6KnKMmTFbBhrTsUrfmlc5CrhfseuAlcTHh2/SPLd4kLm6R6jBJiBNaFUgrBfwJE1UyiznlhujdOko6FMuX7vaPn2lyjCkfIZXSzOc9SukIRbG/xfrgQLlPw1milRGYG5lmrrVtOnxeSPjsqBXlvoaz025qJ+DixN7m4w8Sp8k2uMmMrVkorKI438Oc+fLFJGy7uygMtvmetJv9HXgIlTkHLBZMcIM3F0be28TewcuUOQc68/LmD4xQXwZ5YajFOUVPxTQRWOzMeFMvPjnDGHyUOfwP+I7pVxMdR43wY91Qsq0/56Bw6QXjhosCAwEAAaNQME4wHQYDVR0OBBYEFOk8BOGAL8KEEy0mcJ7y/RrPqv7GMB8GA1UdIwQYMBaAFPNfe3VJ43hBOWogtnxrTFzJPVhBMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggIBAAhYy9VF8ukuCZBqw589VcE2B6ZRQ2OGxukPEodzoOs/RyXYrzX7f4gENrSB8s9HgBgl3lTxP49ZIL89kW51MUHeXlnS3rv8P8Imch8VoW07ekYY6gVRY58dLLe5+q0bfgcPI6boGXw9dUm7plU/1dtBnOOZR39qBIG5D1HJ0wfYLLBc+WeCihrOZSB8+GttFnkiRdzyS0wXn5EYRzbn4vy4Y6S1yJsKwvNoOQoQWUuVyFbiWcd1ZDFomM+HpoF9GFhfyXbWgdnVEO8q036K0OSfW9SZyex/6PQ7F9/7m30N/YMAwcU4nJ6gvkNw3L94vT78IwjSULhmu8oDHAxJ/3enpUINqh8bakRNNmZTl0wtF5wwCYce5siRQPyp798jvUuIxuuuuShvWPPPwh5IdPGC0ezWBYkZsDsY23t5W+nLJfxRZqlF744RM81gMSoyNPRknfYWZAfLtezIOOnhGMBd7nyJapmHZVrn40nNgWbmjTTeo7SuiSqfI4UFMnHoYLVCvjZQUDl087tvJog3WrKEh9pnUdDyw6NeeO/jCxnVeAjr6ixEU5kK2B65YonA+ZxQgPggkrxhI6NAxjphfzvErcEpjYiieKaT75NohhHTtO3tC20CPtn2wuqINkgxl1JbGxvOcKkMNAOwlNX0EqoRQbmWWrgxTFL3ct7/wQCMAAXRMIIFzTCCA7WgAwIBAgIJAJ7TzLHRLKJyMA0GCSqGSIb3DQEBBQUAMH0xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xFzAVBgNVBAoMDkdvb2dsZSBVSyBMdGQuMSEwHwYDVQQLDBhDZXJ0aWZpY2F0ZSBUcmFuc3BhcmVuY3kxITAfBgNVBAMMGE1lcmdlIERlbGF5IE1vbml0b3IgUm9vdDAeFw0xNDA3MTcxMjA1NDNaFw00MTEyMDIxMjA1NDNaMH0xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xFzAVBgNVBAoMDkdvb2dsZSBVSyBMdGQuMSEwHwYDVQQLDBhDZXJ0aWZpY2F0ZSBUcmFuc3BhcmVuY3kxITAfBgNVBAMMGE1lcmdlIERlbGF5IE1vbml0b3IgUm9vdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAKoWHPIgXtgaxWVIPNpCaj2y5Yj9t1ixe5PqjWhJXVNKAbpPbNHA/AoSivecBm3FTD9DfgW6J17mHb+cvbKSgYNzgTk5e2GJrnOP7yubYJpt2OCw0OILJD25NsApzcIiCvLA4aXkqkGgBq9FiVfisReNJxVu8MtxfhbVQCXZf0PpkW+yQPuF99V5Ri+grHbHYlaEN1C/HM3+t2yMR4hkd2RNXsMjViit9qCchIi/pQNt5xeQgVGmtYXyc92ftTMrmvduj7+pHq9DEYFt3ifFxE8v0GzCIE1xR/d7prFqKl/KRwAjYUcpU4vuazywcmRxODKuwWFVDrUBkGgCIVIjrMJWStH5i7WTSSTrVtOD/HWYvkXInZlSgcDvsNIG0pptJaEKSP4jUzI3nFymnoNZn6pnfdIII/XISpYSVeyl1IcdVMod8HdKoRew9CzW6f2n6KSKU5I8X5QEM1NUTmRLWmVi5c75/CvS/PzOMyMzXPf+fE2Dwbf4OcR5AZLTupqp8yCTqo7ny+cIBZ1TjcZjzKG4JTMaqDZ1Sg0T3mO/ZbbiBE3N8EHxoMWpw8OP50z1dtRRwj6qUZ2zLvngOb2EihlMO15BpVZC3Cg929c9Hdl65pUd4YrYnQBQB/rn6IvHo8zot8zElgOg22fHbViijUt3qnRggB40N30MXkYGwuJbAgMBAAGjUDBOMB0GA1UdDgQWBBTzX3t1SeN4QTlqILZ8a0xcyT1YQTAfBgNVHSMEGDAWgBTzX3t1SeN4QTlqILZ8a0xcyT1YQTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4ICAQB3HP6jRXmpdSDYwkI9aOzQeJH4x/HDi/PNMOqdNje/xdNzUy7HZWVYvvSVBkZ1DG/ghcUtn/wJ5m6/orBn3ncnyzgdKyXbWLnCGX/V61PgIPQpuGo7HzegenYaZqWz7NeXxGaVo3/y1HxUEmvmvSiioQM1cifGtz9/aJsJtIkn5umlImenKKEV1Ly7R3Uz3Cjz/Ffac1o+xU+8NpkLF/67fkazJCCMH6dCWgy6SL3AOB6oKFIVJhw8SD8vptHaDbpJSRBxifMtcop/85XUNDCvO4zkvlB1vPZ9ZmYZQdyL43NA+PkoKy0qrdaQZZMq1Jdp+Lx/yeX255/zkkILp43jFyd44rZ+TfGEQN1WHlp4RMj
voGwOX1uGlfoGkRSgBRj7TBn514VYMbXu687RS4WY2v+kny3PUFv/ZBfYSyjoNZnU4Dce9kstgv+gaKMQRPcyL+4vZU7DV8nBIfNFilCXKMN/VnNBKtDV52qmtOsVghgai+QE09w15x7dg+44gIfWFHxNhvHKys+s4BBN8fSxAMLOsb5NGFHE8x58RAkmIYWHjyPM6zB5AUPw1b2A0sDtQmCqoxJZfZUKrzyLz8gS2aVujRYN13KklHQ3EKfkeKBG2KXVBe5rjMN/7Anf1MtXxsTY6O8qIuHZ5QlXhSYzE41yIlPlG6d7AGnTiBIgeg=="; #[test] fn test_deserialize_x509_leaf_input() { let leaf_input_bytes = base64::prelude::BASE64_STANDARD .decode(VALID_CERT_LEAF_INPUT) .unwrap(); let leaf = deserialize_leaf_input(leaf_input_bytes).unwrap(); assert_eq!(leaf.version, 0); assert_eq!(leaf.merkle_leaf_type, 0); assert_eq!(leaf.timestamp, 1740903704893); assert_eq!( leaf.log_entry_type, LogServerEntryType::X509LogServerEntryType ); if let LogServerEntry::X509Entry(data) = &leaf.entry { // somehow according to grok the first three bytes [00, 06, 94] are not part of a normal certificate. Skipping this now let cert = data[3..].to_vec(); let (cn, san) = extract_cn_and_san(cert).unwrap(); assert_eq!(cn, Some("webmail.pnd.dxc.mybluehost.me".to_string())); assert_eq!( san, vec![ "autodiscover.pnd.dxc.mybluehost.me", "cpanel.pnd.dxc.mybluehost.me", "cpcalendars.pnd.dxc.mybluehost.me", "cpcontacts.pnd.dxc.mybluehost.me", "harrisbilling.com", "mail.harrisbilling.com", "mail.pnd.dxc.mybluehost.me", "mail.reneeofalltrades.com", "pnd.dxc.mybluehost.me", "reneeofalltrades.com", "webdisk.pnd.dxc.mybluehost.me", "webmail.pnd.dxc.mybluehost.me", "www.harrisbilling.com", "www.pnd.dxc.mybluehost.me", "www.reneeofalltrades.com" ] ); } else { panic!("Expected X509Entry"); } } // #[test] // fn test_deserialize_leaf_input_precert() { // let leaf_input_bytes = base64::prelude::BASE64_STANDARD.decode(VALID_PRECERT_LEAF_INPUT).unwrap(); // let leaf = deserialize_leaf_input(leaf_input_bytes).unwrap(); // assert_eq!(leaf.version, 0); // assert_eq!(leaf.merkle_leaf_type, 0); // assert_eq!(leaf.timestamp, 0x008A8208EC6B); // assert_eq!( // leaf.log_entry_type, // LogServerEntryType::PrecertLogServerEntryType // ); // if let LogServerEntry::PrecertEntry(precert) = leaf.entry { // assert_eq!(precert.leaf_cert.len(), 32); // assert!(!precert.server_cert.is_empty()); // // Verify server_cert starts with a SEQUENCE tag (0x30) // assert_eq!(precert.server_cert[0], 0x30); // } else { // panic!("Expected PrecertEntry"); // } // } #[test] fn test_invalid_leaf_input_too_short() { let leaf_input = vec![0x00, 0x01]; // Only 2 bytes let result = deserialize_leaf_input(leaf_input); assert!(result.is_err()); assert_eq!( result.unwrap_err(), "Insufficient bytes for MerkleTreeLeaf header" ); } // #[test] // fn test_invalid_log_entry_type() { // let leaf_input = create_leaf_input(0, 0, 123, 2, &[0x01]); // Invalid type 2 // let result = deserialize_leaf_input(leaf_input); // assert!(result.is_err()); // assert_eq!(result.unwrap_err(), "Invalid LogServerEntryType"); // } } ``` ### `services/mrpf_cert_transparency/src/models/mod.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/models/mod.rs` ```rust pub mod logserver; pub mod logserverentry; pub mod logserverrange; pub mod merkletree; pub mod retryconfig; pub use logserver::LogServer; pub use logserverentry::LogServerEntry; pub use logserverrange::{LogServerRange, LogServerRangeState}; pub use merkletree::{MerkleTreeLeaf, MerkleTreeLeafHeader}; pub use retryconfig::RetryConfig; ``` ### `services/mrpf_cert_transparency/src/models/retryconfig.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/models/retryconfig.rs` ```rust use rand::Rng; use std::fmt; use 
std::ops::Range; use std::time::Duration; /// Configuration for retrying operations with rate limiting. /// /// This struct manages retry attempts with configurable backoff strategies, /// useful for handling server rate limits or transient failures. #[derive(PartialEq, Debug, Clone)] pub struct RetryConfig { /// Maximum number of retry attempts allowed before giving up. max_retries: usize, /// Initial delay before the first retry. initial_backoff: Duration, /// Range for the maximum backoff time. /// /// The actual backoff is randomized within this range (after being capped by the /// calculated value) to mitigate thundering herd problems in distributed systems. max_backoff: Range, /// Determines whether to use exponential or linear backoff. /// /// If `true`, backoff doubles with each retry (e.g., 3s, 6s, 12s). /// If `false`, backoff increases linearly using the `step` range. exponential_backoff: bool, /// Range for the step size used in backoff calculations. /// /// For linear backoff, this is the increment per retry (randomized within the range). /// For exponential backoff, it adds jitter. Helps avoid synchronized retries. step: Range, /// Number of retries attempted so far. current_retries: usize, } impl Default for RetryConfig { /// Creates a default `RetryConfig` with sensible values. fn default() -> Self { RetryConfig { max_retries: 2, initial_backoff: Duration::from_secs(3), max_backoff: Duration::from_secs(20)..Duration::from_secs(22), exponential_backoff: true, step: Duration::from_secs(1)..Duration::from_secs(2), current_retries: 0, } } } /// Errors that can occur during retry operations. #[derive(Debug)] pub enum RetryError { /// Returned when the maximum number of retries has been exceeded. MaxRetriesExceeded, /// Returned when a retry is rate-limited, including the duration to wait. RateLimited(Duration), /// Returned on ConnectionError ConnectionError, /// Undefined error UndefinedError(String), } impl fmt::Display for RetryError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { RetryError::MaxRetriesExceeded => write!(f, "Maximum retries exceeded"), RetryError::RateLimited(duration) => { write!(f, "Rate limited, backoff for {:?}", duration) } RetryError::ConnectionError => write!(f, "Connection error"), RetryError::UndefinedError(msg) => write!(f, "Undefined error: {}", msg), } } } impl std::error::Error for RetryError {} impl RetryConfig { /// Computes the backoff time after a server-indicated rate limit. /// /// Call this method after receiving a rate limit response (e.g., HTTP 429) to /// determine how long to wait before retrying. It respects server-provided delays /// if available, otherwise uses the configured backoff strategy. /// /// # Arguments /// * `server_retry_after` - Optional delay suggested by the server (e.g., from a `Retry-After` header). /// /// # Returns /// * `Err(RetryError::RateLimited(duration))` - Wait this long before retrying. /// * `Err(RetryError::MaxRetriesExceeded)` - No more retries are allowed. pub fn rate_limiter(&mut self, server_retry_after: Option) -> RetryError { if self.current_retries >= self.max_retries { return RetryError::MaxRetriesExceeded; } self.current_retries += 1; // Calculate backoff based on config or server hint let backoff = match server_retry_after { Some(duration) => duration, None => self.calculate_backoff(), }; // Return the full backoff duration on rate limit RetryError::RateLimited(backoff) } /// Calculates the backoff duration for the next retry attempt. 
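// Illustrative usage of `RetryConfig::rate_limiter` above (added annotation, not part of the
// original source), assuming a Tokio runtime and a hypothetical `fetch_page` request helper:
//
//     let mut retry = RetryConfig::default();
//     let page = loop {
//         match fetch_page().await {
//             Ok(body) => { retry.reset(); break body; }
//             Err(_throttled) => match retry.rate_limiter(None) {
//                 RetryError::RateLimited(wait) => tokio::time::sleep(wait).await,
//                 other => panic!("giving up: {other}"),
//             },
//         }
//     };
//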
/// /// Uses either exponential or linear backoff, capped by `max_backoff`, with randomization /// to prevent synchronized retries across multiple clients. fn calculate_backoff(&self) -> Duration { let mut rng = rand::rng(); let step = rng.random_range(self.step.start..self.step.end); let base = if self.exponential_backoff { // Cap to avoid overflow self.initial_backoff * 2u32.pow(self.current_retries.min(31) as u32) } else { self.initial_backoff + step * self.current_retries as u32 } .max(Duration::from_secs(1)); rng.random_range(base..self.max_backoff.start.max(base + Duration::from_secs(1))) .min(self.max_backoff.end) } /// Resets the retry state to its initial values. /// /// Use this after a successful operation or when retry context should be cleared. pub fn reset(&mut self) { self.current_retries = 0; } } #[cfg(test)] mod tests { use super::*; use std::time::Duration; // Helper to create a config with specific settings fn setup_config(exponential: bool, max_retries: usize) -> RetryConfig { RetryConfig { max_retries, initial_backoff: Duration::from_secs(2), max_backoff: Duration::from_secs(10)..Duration::from_secs(12), exponential_backoff: exponential, step: Duration::from_secs(1)..Duration::from_secs(2), current_retries: 0, } } #[test] fn test_exponential_backoff_calculation() { let config = setup_config(true, 3); let backoff_0 = config.calculate_backoff(); // 1st retry assert!(backoff_0 >= Duration::from_secs(2) && backoff_0 <= Duration::from_secs(12)); let mut config = setup_config(true, 3); config.current_retries = 1; let backoff_1 = config.calculate_backoff(); // 2nd retry assert!(backoff_1 >= Duration::from_secs(4) && backoff_1 <= Duration::from_secs(12)); config.current_retries = 2; let backoff_2 = config.calculate_backoff(); // 3rd retry assert!(backoff_2 >= Duration::from_secs(8) && backoff_2 <= Duration::from_secs(12)); } #[test] fn test_linear_backoff_calculation() { let mut config = setup_config(false, 3); let backoff_0 = config.calculate_backoff(); // 1st retry assert!(backoff_0 >= Duration::from_secs(2) && backoff_0 <= Duration::from_secs(12)); config.current_retries = 1; let backoff_1 = config.calculate_backoff(); // 2nd retry assert!(backoff_1 >= Duration::from_secs(3) && backoff_1 <= Duration::from_secs(12)); config.current_retries = 2; let backoff_2 = config.calculate_backoff(); // 3rd retry assert!(backoff_2 >= Duration::from_secs(4) && backoff_2 <= Duration::from_secs(12)); } #[test] fn test_rate_limiter_first_attempt() { let mut config = setup_config(true, 3); let result = config.rate_limiter(None); match result { RetryError::RateLimited(duration) => { assert!(duration >= Duration::from_secs(2) && duration <= Duration::from_secs(12)); } _ => panic!("Expected RateLimited error"), } assert_eq!(config.current_retries, 1); } #[test] fn test_rate_limiter_server_retry_after() { let mut config = setup_config(true, 3); let server_delay = Duration::from_secs(5); let result = config.rate_limiter(Some(server_delay)); match result { RetryError::RateLimited(duration) => { assert_eq!(duration, server_delay); } _ => panic!("Expected RateLimited error"), } assert_eq!(config.current_retries, 1); // Second call with same server delay let result = config.rate_limiter(Some(server_delay)); match result { RetryError::RateLimited(duration) => { assert_eq!(duration, server_delay); } _ => panic!("Expected RateLimited error"), } assert_eq!(config.current_retries, 2); } #[test] fn test_rate_limiter_rate_limited() { let mut config = setup_config(true, 3); let result = 
config.rate_limiter(None); match result { RetryError::RateLimited(duration) => { assert!(duration >= Duration::from_secs(2) && duration <= Duration::from_secs(12)); } _ => panic!("Expected RateLimited error"), } assert_eq!(config.current_retries, 1); let result = config.rate_limiter(None); match result { RetryError::RateLimited(duration) => { assert!(duration >= Duration::from_secs(4) && duration <= Duration::from_secs(12)); } _ => panic!("Expected RateLimited error"), } assert_eq!(config.current_retries, 2); } #[test] fn test_rate_limiter_max_retries_exceeded() { let mut config = setup_config(true, 2); // Max 2 retries config.rate_limiter(None); // 1st retry config.rate_limiter(None); // 2nd retry assert_eq!(config.current_retries, 2); let result = config.rate_limiter(None); // 3rd retry exceeds max match result { RetryError::MaxRetriesExceeded => (), _ => panic!("Expected MaxRetriesExceeded error"), } assert_eq!(config.current_retries, 2); // Should not increment past max } #[test] fn test_reset() { let mut config = setup_config(true, 3); config.rate_limiter(None); assert_eq!(config.current_retries, 1); config.reset(); assert_eq!(config.current_retries, 0); } } ``` ### `services/mrpf_cert_transparency/src/utils/extractcert.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/utils/extractcert.rs` ```rust #[derive(Debug)] pub enum Error { InvalidData, InvalidTag, InvalidLength, CommonNameNotFound, InvalidString, } use tracing::error; impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) } } impl std::error::Error for Error {} // OID constants (DER encoded value only, excluding tag/length) const OID_CN: &[u8] = b"\x55\x04\x03"; const OID_SAN: &[u8] = b"\x55\x1D\x11"; #[inline(always)] #[expect(clippy::type_complexity)] fn parse_tlv(data: &[u8]) -> Result<(u8, usize, &[u8], &[u8]), Error> { if data.len() < 2 { error!("InvalidData: data.len()={}", data.len()); return Err(Error::InvalidData); } let tag = data[0]; let length_byte = data[1]; let mut offset = 2; // After tag and length byte let mut length = length_byte as usize; if length_byte >= 0x80 { let len_bytes = (length_byte & 0x7F) as usize; if offset + len_bytes > data.len() { return Err(Error::InvalidLength); } length = 0; for &byte in &data[offset..offset + len_bytes] { length = (length << 8) + byte as usize; } offset += len_bytes; } if offset + length > data.len() { error!( "InvalidLength: offset={}, length={}, data.len()={}", offset, length, data.len() ); error!("Data: {:?}", &data[..std::cmp::min(data.len(), 16)]); return Err(Error::InvalidLength); } let value = &data[offset..offset + length]; let remaining = &data[offset + length..]; Ok((tag, length, value, remaining)) } #[inline(always)] fn skip_tlv(data: &[u8]) -> Result<&[u8], Error> { let (_, _, _, remaining) = parse_tlv(data)?; Ok(remaining) } /// Extract common name from given certificate subject fn extract_cn(subject: &[u8]) -> Result { let mut data = subject; while !data.is_empty() { let (tag, _, value, rest) = parse_tlv(data)?; // debug!("Subject Tag: {:02X}", tag); if tag != 0x31 { // SET return Err(Error::InvalidTag); } let (seq_tag, _, seq_value, _seq_rest) = parse_tlv(value)?; // debug!(" Sequence Tag: {:02X}", seq_tag); if seq_tag != 0x30 { // SEQUENCE (AttributeTypeAndValue) return Err(Error::InvalidTag); } let (oid_tag, _, oid_value, oid_rest) = parse_tlv(seq_value)?; // debug!(" OID Tag: {:02X}, OID: {:02X?}", oid_tag, oid_value); if oid_tag != 0x06 { // OBJECT IDENTIFIER 
data = rest; // Advance past entire SET continue; } if oid_value == OID_CN { // 2.5.4.3 (CN) let (_str_tag, _, str_value, _) = parse_tlv(oid_rest)?; // debug!(" String Tag: {:02X}", str_tag); return std::str::from_utf8(str_value) .map(|s| s.to_string()) .map_err(|_| Error::InvalidString); } data = rest; // Advance past entire SET for non-CN } Err(Error::CommonNameNotFound) } fn extract_san(extensions: &[u8]) -> Result, Error> { let mut data = extensions; let mut san_names = Vec::new(); while !data.is_empty() { let (tag, _, value, rest) = parse_tlv(data)?; if tag != 0x30 { // Extension SEQUENCE error!(" InvalidTag extension SEQUENCE: {:02X}", tag); return Err(Error::InvalidTag); } let (oid_tag, _, oid_value, mut inner_rest) = parse_tlv(value)?; // debug!(" OID Tag: {:02X}, OID: {:02X?}", oid_tag, oid_value); if oid_tag != 0x06 { // OBJECT IDENTIFIER data = rest; continue; } if oid_value == OID_SAN { // 2.5.29.17 (SAN) // debug!(" Found SAN OID"); loop { let (next_tag, next_length, next_value, next_rest) = parse_tlv(inner_rest)?; // debug!( // " Next Tag: {:02X}, Length: {}, Value start: {:02X?}", // next_tag, // next_length, // &next_value[..std::cmp::min(next_value.len(), 8)] // ); if next_tag == 0x01 { // BOOLEAN (criticality) // debug!(" Found criticality:"); if next_length != 1 || next_value.len() != 1 { error!(" Invalid criticality BOOLEAN length: {}", next_length); return Err(Error::InvalidData); } // debug!(" Criticality: {:02X}", next_value[0]); inner_rest = next_rest; // Advance past BOOLEAN (3 bytes total) continue; // Restart loop to parse next TLV (OCTET STRING) } else if next_tag == 0x04 { // OCTET STRING let (seq_tag, _, seq_content, _) = parse_tlv(next_value)?; // debug!( // " Octet Tag: {:02X}, Sequence Tag: {:02X}", // next_tag, seq_tag // ); if seq_tag != 0x30 { // SEQUENCE of GeneralNames error!(" InvalidTag GeneralNames SEQUENCE: {:02X}", seq_tag); return Err(Error::InvalidTag); } san_names.extend(extract_dns_names(seq_content)?); return Ok(san_names); } else { error!( " InvalidTag expecting BOOLEAN or OCTET STRING: {:02X}", next_tag ); return Err(Error::InvalidTag); } } } data = rest; } Ok(san_names) } // Helper function to extract dNSNames from GeneralNames SEQUENCE fn extract_dns_names(seq_content: &[u8]) -> Result, Error> { let mut san_data = seq_content; let mut san_names = Vec::new(); while !san_data.is_empty() { let (san_tag, _, san_value, rest_san) = parse_tlv(san_data)?; // debug!( // " SAN Tag: {:02X}, Value start: {:02X?}", // san_tag, // &san_value[..std::cmp::min(san_value.len(), 8)] // ); // Other tags I've seen, but not handling yet: // 0x81: rfc822Name (email) // 0x82: IA5String (dNSName) // 0x87: IP address // 0x86: URI // The full list of GeneralName types in SAN (per RFC 5280, Section 4.2.1.6) is: // otherName (0xA0, with inner OID and value) // rfc822Name (0x81 or sometimes 0x82) This is the email address of the issuer // dNSName (0x82) // x400Address (0xA3, rare, legacy X.400 protocol) // directoryName (0xA4 or 0xA0 depending on tools, typically a SEQUENCE) // ediPartyName (0xA5, extremely rare) // uniformResourceIdentifier (0x86) // iPAddress (0x87) // registeredID (0x88 or 0x83) if san_tag == 0x82 { // IA5String (dNSName) let name = std::str::from_utf8(san_value) .map(|s| s.to_string()) .map_err(|_| Error::InvalidString)?; san_names.push(name); } san_data = rest_san; } Ok(san_names) } /// Extract the Common name and Subject Alternative Names from a certificate pub fn extract_cn_and_san(cert: Vec) -> Result<(Option, Vec), Error> { // 
debug!("{:02X?}...", &cert); let (tag, _, value, _) = parse_tlv(&cert)?; // debug!("{:02X}: {:02X?}...", tag, &value[..4]); if tag != 0x30 { // Certificate SEQUENCE error!( "InvalidTag cert SEQUENCE: {:02X}: {:02X?}...", tag, &value[..4] ); return Err(Error::InvalidTag); } let (tbs_tag, _, mut data, _) = parse_tlv(value)?; // debug!("{:02X}: {:02X?}...", tbs_tag, &data[..4]); if tbs_tag != 0x30 { // tbsCertificate SEQUENCE error!( "InvalidTag tbsCertificate SEQUENCE: {:02X}: {:02X?}...", tbs_tag, &data[..4] ); return Err(Error::InvalidTag); } if !data.is_empty() && data[0] == 0xA0 { // Skip version data = skip_tlv(data)?; } // serialNumber, signature, issuer, validity data = skip_tlv(data)?; // serialNumber data = skip_tlv(data)?; // signature data = skip_tlv(data)?; // issuer data = skip_tlv(data)?; // validity // Parse subject let (subject_tag, _, subject_value, rest) = parse_tlv(data)?; if subject_tag != 0x30 { return Err(Error::InvalidTag); } // Extract CN let cn = match extract_cn(subject_value) { Ok(cn) => Some(cn), Err(Error::CommonNameNotFound) | Err(Error::InvalidString) => None, Err(e) => { let cn_lossy = String::from_utf8_lossy(subject_value).to_string(); error!("Failed to extract cn from {:?}: {:?}", cn_lossy, e); None } }; data = skip_tlv(rest)?; // Skip subjectPublicKeyInfo // Iterate through remaining optional fields to find extensions // and extract SAN names let mut san_names = Vec::new(); while !data.is_empty() { let (tag, _, value, rest) = parse_tlv(data)?; if tag == 0xA3 { // Extensions let (ext_tag, _, ext_value, _) = parse_tlv(value)?; // debug!("{:02X?}...", &ext_value[..4]); if ext_tag != 0x30 { return Err(Error::InvalidTag); } // debug!("Extracting SAN names"); san_names = extract_san(ext_value)?; break; } else if tag == 0xA1 || tag == 0xA2 { // debug!("Found issuerUniqueID or subjectUniqueID"); // Skip issuerUniqueID or subjectUniqueID data = rest; } else { break; } } Ok((cn, san_names)) } ``` ### `services/mrpf_cert_transparency/src/utils/fqdn.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/utils/fqdn.rs` ```rust /// Cleanup the FQDN by removing any whitespace, newlines, dots, and wildcard characters /// TODO: Probably should be in mrpf engine or utils crate /// UPDATE: duplicated this function to mrpf_cert_parser crate, and renamed it to clean_fqdn pub fn fqdn_cleanup(fqdn: &mut String) { // Remove any whitespace and newlines fqdn.retain(|c| !c.is_whitespace() && c != '\n'); // Remove any leading dots and wildcards while fqdn.starts_with('.') || fqdn.starts_with('*') || fqdn.starts_with('?') { fqdn.remove(0); } // Remove any trailing dots while fqdn.ends_with('.') { fqdn.pop(); } } ``` ### `services/mrpf_cert_transparency/src/utils/mod.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/utils/mod.rs` ```rust pub mod extractcert; pub mod fqdn; pub mod transformations; pub use extractcert::extract_cn_and_san; pub use fqdn::fqdn_cleanup; pub use transformations::{aggregate_ranges, split_ranges_to_scan}; ``` ### `services/mrpf_cert_transparency/src/utils/transformations.rs` **Absolute path:** `/root/mrpf/services/mrpf_cert_transparency/src/utils/transformations.rs` ```rust use crate::{LogServerRange, LogServerRangeState}; use std::collections::HashSet; use tracing::debug; /// Aggregate log server ranges for distributed scanner tasks /// /// This will iterate through any existing log server ranges and determine /// what ranges need to be scanned. 
/// /// To minimize the amount of records in the database, we will combine any adjacent or overlapping /// ranges into a single big range. All existing ranges that are currently in progress /// will be re-added for scanning, as we will assume they failed in a previous run. /// /// # Arguments /// /// * `log_server` - The log server to generate the ranges for /// * `existing_ranges` - The existing ranges for the log server. These MUST be sorted from low to high (which is tackled by DynamoDB already) /// * `tree_size` - The total size of available entries /// /// # Returns /// /// * `ranges_to_delete` - The ranges that need to be removed from the database. These will be older ranges that are now aggregated into a single range /// * `ranges_to_add` - The ranges that need to be added to the database. These will have either the state Done or Pending /// * `ranges_to_scan` - The ranges that need to be scanned still. These will only have the state Pending pub fn aggregate_ranges( log_server: &str, mut existing_ranges: Vec, tree_size: i64, ) -> ( HashSet, Vec, Vec, ) { let mut ranges_to_delete = HashSet::new(); let mut ranges_to_add = Vec::new(); let mut ranges_to_scan = Vec::new(); // We will first remove all pending ranges from the existing ranges. existing_ranges.retain(|r| { if r.state() == LogServerRangeState::Pending { ranges_to_delete.insert(r.clone()); false } else { true } }); if existing_ranges.is_empty() { debug!("No existing completed ranges found, adding a single range for the entire tree"); ranges_to_add.push(LogServerRange::new( log_server.to_string(), LogServerRangeState::Pending, 0, tree_size, )); ranges_to_scan.push(LogServerRange::new( log_server.to_string(), LogServerRangeState::Pending, 0, tree_size, )); } else { debug!("Existing ranges: {:?}", existing_ranges); // Merge all adjacent ranges into a single range // // We will merge all ranges together first. This will ensure we're able to // change the range sizes that a single lambda can process more easily. let mut all_ranges = Vec::with_capacity(existing_ranges.len()); let mut last_end = -1; // Sort existing ranges by start index. This SHOULD already be ok due to dynamodb sorting // but is an extra failsafe. 
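// Illustrative walk-through of the gap filling below: existing Done ranges 0-99 and
// 200-299 with tree_size 400 yield all_ranges = [Done 0-99, Pending 100-199,
// Done 200-299, Pending 300-400]; how adjacent ranges are merged afterwards depends
// on LogServerRange::combine.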
let mut existing_ranges = existing_ranges; existing_ranges.sort_by_key(|r| r.start()); // Add range that starts at 0 if it's not present if existing_ranges[0].start() > 0 { debug!("Adding range 0-{}", existing_ranges[0].start() - 1); all_ranges.push(LogServerRange::new( log_server.to_string(), LogServerRangeState::Pending, 0, existing_ranges[0].start() - 1, )); } // Iterate through all existing ranges and fill the gaps for mut range in existing_ranges { if range.log_server() != log_server { panic!( "Existing range log server {} does not match {}", range.log_server(), log_server ); } if range.start() > last_end + 1 { debug!("Filling gap between {} and {}", last_end, range.start()); all_ranges.push(LogServerRange::new( log_server.to_string(), LogServerRangeState::Pending, last_end + 1, range.start() - 1, )); } if range.end() > tree_size { debug!( "Range end {} is larger than tree size {}", range.end(), tree_size ); ranges_to_delete.insert(range.clone()); range.set_end(tree_size); } last_end = range.end(); all_ranges.push(range); } // Add a range till the tree size if last_end < tree_size { debug!( "Adding range {}-{} till end of tree", last_end + 1, tree_size ); all_ranges.push(LogServerRange::new( log_server.to_string(), LogServerRangeState::Pending, last_end + 1, tree_size, )); } // Now we iterate through all ranges and combine them if they are adjacent let mut current = all_ranges[0].clone(); for next in &all_ranges[1..] { if let Some(combined) = current.combine(next) { // Remove the smaller ranges and set combined range as current debug!( "Combining ranges {}-{} {} with {}-{} {}.", current.start(), current.end(), current.state(), next.start(), next.end(), next.state() ); ranges_to_delete.insert(current.clone()); ranges_to_delete.insert(next.clone()); ranges_to_add.retain(|r| r != ¤t && r != &next.clone()); current = combined; } else { // Add the current range, which can potentially be a previously combined range ranges_to_add.push(current.clone()); if current.state() == LogServerRangeState::Pending { ranges_to_scan.push(current); } current = next.clone(); } } // Add the last range if not already there if ranges_to_add.last() != Some(¤t) { ranges_to_add.push(current.clone()); if current.state() == LogServerRangeState::Pending { ranges_to_scan.push(current); } } debug!( "{} ranges after aggregating and filling the gaps until end of tree", ranges_to_add.len() ); } (ranges_to_delete, ranges_to_add, ranges_to_scan) } /// Split all ranges_to_add into smaller ranges that can be processed by a single lambda /// /// /// We need to limit the maximum amount of ranges to scan, otherwise /// we will run out of memory in the lambda. Especially when the tree /// size is very large, which it will be initially. /// /// Our lambda is on a schedule so should eventually catch up to the /// latest entries. /// /// Another reason for limiting the amount of ranges to scan is to avoid /// generating too many SQS messages, especially when we're behind on /// scanning the log servers. 
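///
/// Illustrative example: with `split_size = 1_000` and `max_range_count = 5`, a single
/// large Pending range yields at most five 1_000-entry chunks per invocation; whatever
/// falls outside that cap is picked up again by a later scheduled run (see
/// `aggregate_ranges`).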
/// /// # Arguments /// /// * `ranges_to_add` - The ranges that need to be added to the database /// * `ranges_to_scan` - The ranges that need to be scanned /// * `split_size` - The maximum size of a single range /// * `max_range_count` - The maximum amount of ranges to scan /// /// # Returns /// /// Tuple of the following: /// /// * `ranges_to_add` - The ranges that need to be added to the database /// * `split_ranges_to_scan` - The ranges that need to be scanned, split into smaller ranges pub fn split_ranges_to_scan( mut ranges_to_add: Vec, ranges_to_scan: Vec, split_size: i64, max_range_count: usize, ) -> (Vec, Vec) { debug!( "Splitting {} ranges to scan into smaller chunks of max {} entries", ranges_to_scan.len(), split_size ); let mut split_ranges_to_scan = Vec::with_capacity(max_range_count); for range in ranges_to_scan { let remaining_capacity = max_range_count - split_ranges_to_scan.len(); if remaining_capacity > 0 { let split_ranges = range.split(split_size, remaining_capacity); debug!( "Split range {}-{} into {} smaller ranges", range.start(), range.end(), split_ranges.len() ); // Collect split ranges once let split_vec: Vec = split_ranges.into_iter().take(remaining_capacity).collect(); // Add to split_ranges_to_scan split_ranges_to_scan.extend(split_vec.iter().cloned()); // Remove the original range and add only new split ranges that aren't already present ranges_to_add.retain(|r| r != &range); for split_range in &split_vec { if !ranges_to_add.contains(split_range) { ranges_to_add.push(split_range.clone()); } } // Only add what fits within the remaining capacity // let to_add = split_ranges.into_iter().take(remaining_capacity); // split_ranges_to_scan.extend(to_add.clone()); // ranges_to_add.retain(|r| r != &range); // ranges_to_add.extend(to_add); } else { debug!( "Reached maximum amount of ranges ({}) to scan for log server", max_range_count ); break; } } debug!( "{} ranges to add, {} ranges to scan", ranges_to_add.len(), split_ranges_to_scan.len() ); (ranges_to_add, split_ranges_to_scan) } ``` ### `services/mrpf_scanner_api/README.md` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/README.md` ```markdown # MRPF Scanner WebSocket API The WebSocket API provides a way for clients to start, stop and monitor MRPF scans. We will maintain state of connected clients and running scans on the server side. Each client can see what scans are running and they will dynamically get updates pushed as they come in. *TODO* what happens when a client disconnects? Do we want to keep the scan running and store it in memory or do we want to cancel it? *TODO* Do we want the scanner to only allow a single scan to happen at a time? Or do we want to allow multiple scans to happen at the same time? ## Copilot suggested message structure Creating a JSON message format for managing network scans involves defining a clear and consistent structure for your commands and responses. Here's a suggested format that you can adapt to your needs: ### JSON Message Format #### 1. Starting a Scan ```json { "action": "start", "scan_type": "TcpSyn", // or "TLS", "HTTP", "DNS" "parameters": { "target": "192.168.1.1", "port": 80, "options": { "timeout": 1000, "retries": 3 } } } ``` #### 2. Pausing a Scan ```json { "action": "pause", "scan_id": "12345" } ``` #### 3. Stopping a Scan ```json { "action": "stop", "scan_id": "12345" } ``` #### 4. Retrieving List of Running Scans ```json { "action": "list" } ``` #### 5. 
Retrieving Details of a Specific Scan ```json { "action": "details", "scan_id": "12345" } ``` ### Example Responses #### 1. Scan Started ```json { "status": "success", "message": "Scan started", "scan_id": "12345" } ``` #### 2. Scan Paused ```json { "status": "success", "message": "Scan paused", "scan_id": "12345" } ``` #### 3. Scan Stopped ```json { "status": "success", "message": "Scan stopped", "scan_id": "12345" } ``` #### 4. List of Running Scans ```json { "status": "success", "scans": [ { "scan_id": "12345", "scan_type": "TcpSyn", "target": "192.168.1.1", "status": "running" }, { "scan_id": "67890", "scan_type": "HTTP", "target": "example.com", "status": "paused" } ] } ``` #### 5. Scan Details ```json { "status": "success", "scan": { "scan_id": "12345", "scan_type": "TcpSyn", "target": "192.168.1.1", "status": "running", "results": [ { "timestamp": "2024-12-17T15:54:50Z", "data": "Open port 80" } ] } } ``` ### Explanation - **Action**: Specifies the action to be taken (`start`, `pause`, `stop`, `list`, `details`). - **Scan Type**: Indicates the type of scan (`TcpSyn`, `TLS`, `HTTP`, `DNS`). - **Parameters**: Contains the necessary parameters for starting a scan, such as target IP, port, and additional options. - **Scan ID**: A unique identifier for each scan, used to manage and retrieve specific scans. - **Status**: Indicates the success or failure of the action. - **Message**: Provides additional information about the action taken. - **Results**: Contains the results of the scan, which can be pushed back as messages during the scan. This structure should help you manage your network scans effectively and ensure clear communication between your websocket server and clients. If you have any specific requirements or need further customization, feel free to ask! 
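### Minimal Rust client sketch (current prototype)

The prototype in `src/main.rs` and `src/messages.rs` currently uses a slightly different
format than the suggestion above: messages are tagged with `scan_type` (`tcp_syn`, `tls`,
`http1`, `list`) rather than an `action` field, and the server listens on
`ws://127.0.0.1:6666`. A minimal client sketch against that implementation (target IP and
port are placeholders):

```rust
use tungstenite::{Message, connect};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to the address the prototype server binds to in src/main.rs.
    let (mut socket, _response) = connect("ws://127.0.0.1:6666")?;

    // Start a TCP SYN scan using the tagged message format from src/messages.rs.
    let start = serde_json::json!({
        "scan_type": "tcp_syn",
        "ipv4_addresses": ["192.0.2.1"],
        "ports": "443"
    });
    socket.send(Message::text(start.to_string()))?;

    // Print the result/stats/finished messages the server pushes back.
    while let Ok(msg) = socket.read() {
        if let Ok(text) = msg.to_text() {
            println!("{text}");
        }
    }
    Ok(())
}
```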
``` ### `services/mrpf_scanner_api/src/main.rs` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/src/main.rs` ```rust mod messages; mod progress; use messages::{ScanMessage, ScanState, ScanType}; use mrpf_engine::ScannerConfig; use mrpf_engine::utils::get_interface; use std::net::{Ipv4Addr, TcpListener, TcpStream}; use std::sync::{Arc, Mutex}; use std::thread::spawn; use tungstenite::Message; use tungstenite::{ accept_hdr, handshake::server::{Request, Response}, }; fn main() { let interface = get_interface(Some("en0")); let src_mac = interface.mac.expect("No MAC address found"); // Home IPs let src_ip = Ipv4Addr::new(192, 168, 2, 244); // Spoofed IP let router_ip = Ipv4Addr::new(192, 168, 2, 254); let config = ScannerConfig::builder() .interface(interface) .src_mac(src_mac) .src_ip(src_ip) .router_ip(router_ip) .build(); let addr = "127.0.0.1:6666"; let listener = TcpListener::bind(addr).unwrap(); let state: Arc>> = Arc::new(Mutex::new(Vec::new())); println!("running websocket server: ws://{:}/index.html", addr); for stream in listener.incoming() { println!("New client connected"); let config_clone = config.clone(); let state_clone = Arc::clone(&state); spawn(move || { let stream = stream.unwrap(); handle_client(config_clone, stream, &state_clone); }); } } /// Handle a single client connection fn handle_client(config: ScannerConfig, stream: TcpStream, state: &Arc>>) { let callback = |_req: &Request, response: Response| Ok(response); let websocket = accept_hdr(stream, callback).unwrap(); let websocket = Arc::new(Mutex::new(websocket)); loop { // Loop read messages from the client let ws_clone = Arc::clone(&websocket); let msg = { let mut ws = ws_clone.lock().unwrap(); ws.read() }; match msg { Ok(msg) => { let text = msg.to_text().unwrap(); match serde_json::from_str::(text) { Ok(scan_message) => match scan_message { ScanMessage::TcpSynScan(scan) => { println!("Received TcpSynScan message: {:?}", scan); { let mut state_lock = state.lock().unwrap(); state_lock.push(ScanState::new(ScanType::TcpSyn)); } scan.run(ws_clone, config.clone()); } ScanMessage::TlsScan(scan) => { println!("Received TLS scan message: {:?}", scan); { let mut state_lock = state.lock().unwrap(); state_lock.push(ScanState::new(ScanType::Tls)); } scan.run(ws_clone, config.clone()); } ScanMessage::Http1Scan(scan) => { println!("Received HTTP1 scan message: {:?}", scan); // println!("{}", String::from_utf8_lossy(&scan.template)); { let mut state_lock = state.lock().unwrap(); state_lock.push(ScanState::new(ScanType::Http1)); } scan.run(ws_clone, config.clone()); } ScanMessage::ScanList(_) => { println!("Received ScanList message"); let mut websocket = websocket.lock().unwrap(); let state_lock = state.lock().unwrap(); let scans_json = serde_json::to_string(&*state_lock).unwrap(); websocket .send(Message::text(format!("{{\"scans\": {}}}", scans_json))) .unwrap(); } }, Err(e) => { println!( "Error parsing message: {:?}.\noriginal message:\n{}", e, text ); break; } } } Err(tungstenite::Error::Protocol( tungstenite::error::ProtocolError::ResetWithoutClosingHandshake, )) => { // TODO: Should we cache all results and send them to the client when they reconnect? 
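// One option (not implemented): keep per-scan result buffers in the shared `state`
// so a re-connecting client can be brought back up to date before live updates resume.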
println!("Client disconnected without closing handshake"); break; } Err(e) => { println!("Error reading message: {:?}", e); break; } } println!("Finished scanning\n"); } } ``` ### `services/mrpf_scanner_api/src/messages.rs` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/src/messages.rs` ```rust use super::progress::{ Http1ProgressHandlerWebSocket, TcpSynProgressHandlerWebSocket, TlsProgressHandlerWebSocket, }; use chrono::Utc; use mrpf_engine::ScannerConfig; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use mrpf_http1_scanner::{Http1Scanner, Template, TemplateVariables}; use mrpf_tcp_syn_scanner::TcpSynScanner; use mrpf_tls_scanner::TlsScanner; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::TcpStream; use std::sync::{Arc, Mutex}; use tungstenite::WebSocket; #[derive(Serialize, Deserialize, Debug)] pub(crate) enum ScanType { TcpSyn, Tls, Http1, } #[derive(Serialize, Deserialize, Debug)] pub(crate) struct ScanState { start_time: i64, scan_type: ScanType, } impl ScanState { pub(crate) fn new(scan_type: ScanType) -> Self { ScanState { start_time: Utc::now().timestamp(), scan_type, } } } #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "scan_type")] pub(crate) enum ScanMessage { #[serde(rename = "tcp_syn")] TcpSynScan(TcpSynScan), #[serde(rename = "tls")] TlsScan(TlsScan), #[serde(rename = "http1")] Http1Scan(Http1Scan), #[serde(rename = "list")] ScanList(ScanList), } #[derive(Serialize, Deserialize, Debug)] pub(crate) struct ScanList { pub(crate) scans: Vec, } /// Message to start a TCP SYN scan #[derive(Serialize, Deserialize, Debug)] pub(crate) struct TcpSynScan { pub ipv4_addresses: Vec, pub ports: String, } impl TcpSynScan { pub fn run(&self, websocket: Arc>>, config: ScannerConfig) { // For instance use this payload // { // "scan_type": "tcp_syn", // "ipv4_addresses": ["136.144.153.226"], // "ports": "443" // } let ipv4_addresses: Ipv4AddressRanges = self.ipv4_addresses.clone().into(); let ports: PortRanges = self.ports.clone().into(); let scanner = TcpSynScanner::new(config.clone()); let progress_handler = TcpSynProgressHandlerWebSocket::new( Arc::clone(&websocket), ipv4_addresses.len(), ports.len(), ); let _open_ports = scanner.run_with_handle(ipv4_addresses, ports, progress_handler); } } /// Message to start a TLS scan #[derive(Serialize, Deserialize, Debug)] pub(crate) struct TlsScan { /// Vec<(Vec, Vec, Ports)> pub targets: Vec<(Vec, Vec, String)>, } impl TlsScan { pub fn run(&self, websocket: Arc>>, config: ScannerConfig) { // For instance use this payload // { // "scan_type": "tls", // "targets": [ // [["example.com", "www.mortimer.nl"], ["136.144.153.226", "93.184.215.14"], "443"] // ] // } let targets: Vec<(Vec, Ipv4AddressRanges, PortRanges)> = self .targets .iter() .map(|(snis, ip_addresses, port)| { ( snis.clone(), ip_addresses.clone().into(), port.clone().into(), ) }) .collect(); println!("{:?}", targets); let scanner = TlsScanner::new(config.clone()); let progress_handler = TlsProgressHandlerWebSocket::new(Arc::clone(&websocket), &targets); let _results = scanner.run_with_handle(targets, progress_handler); } } /// Message to start a Http1 scan #[derive(Serialize, Deserialize, Debug)] pub(crate) struct Http1Scan { pub ipv4_addresses: Vec, pub ports: String, pub snis: Vec, #[serde(deserialize_with = "string_to_bytes")] pub template: Template, #[serde(deserialize_with = "variables_to_bytes")] pub variables: TemplateVariables, } fn string_to_bytes<'de, D>(deserializer: D) -> Result, D::Error> where D: 
serde::Deserializer<'de>, { let s: String = serde::Deserialize::deserialize(deserializer)?; Ok(s.into_bytes()) } #[expect(clippy::type_complexity)] fn variables_to_bytes<'de, D>(deserializer: D) -> Result, Vec>>, D::Error> where D: serde::Deserializer<'de>, { let map: HashMap> = serde::Deserialize::deserialize(deserializer)?; let mut result = HashMap::new(); for (key, value) in map { let key_bytes = key.into_bytes(); let value_bytes = value.into_iter().map(|s| s.into_bytes()).collect(); result.insert(key_bytes, value_bytes); } Ok(result) } impl Http1Scan { pub fn run(&self, websocket: Arc>>, config: ScannerConfig) { // For instance use this payload // { // "scan_type": "http1", // "ipv4_addresses": ["136.144.153.226"], // "ports": "443", // "sni": null, // "template": "GET / HTTP/1.1\r\nHost: ${host}\r\nConnection: close\r\n\r\n", // "variables": {"${host}": ["example.com"]} // } let ipv4_addresses: Ipv4AddressRanges = self.ipv4_addresses.clone().into(); let ports: PortRanges = self.ports.clone().into(); let snis = self.snis.clone(); let template: Template = self.template.clone(); let variables: TemplateVariables = self .variables .clone() .into_iter() .filter(|(_, v)| !v.is_empty()) .collect(); let scanner = Http1Scanner::new(config.clone()); let progress_handler = Http1ProgressHandlerWebSocket::new( Arc::clone(&websocket), ipv4_addresses.len(), ports.len(), &variables, ); let _results = scanner.run_with_handle( ipv4_addresses, ports, snis, template, variables, progress_handler, ); } } ``` ### `services/mrpf_scanner_api/src/progress/https.rs` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/src/progress/https.rs` ```rust use mrpf_engine::ProgressHandler; use mrpf_http1_scanner::{Http1Message, TemplateVariables}; use mrpf_http1_scanner::{Http1Request, Http1Response}; use std::net::TcpStream; use std::sync::mpsc::Receiver; use std::sync::{Arc, Mutex}; use tungstenite::Message; use tungstenite::WebSocket; /// CLI ProgressHandler for the Http1 Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. A progress /// bar will be displayed in the CLI to show the progress of the scan. #[derive(Debug)] pub struct Http1ProgressHandlerWebSocket { total_targets: u64, websocket: Arc>>, } impl Http1ProgressHandlerWebSocket { pub fn new( websocket: Arc>>, ip_address_count: u32, port_count: u16, variables: &TemplateVariables, ) -> Self { let variable_count: usize = variables.values().map(|values| values.len()).product(); let total_targets = (ip_address_count * port_count as u32) as u64 * variable_count as u64; Self { total_targets, websocket, } } } impl ProgressHandler for Http1ProgressHandlerWebSocket { type Message = Http1Message; type Results = Vec<(Http1Request, Http1Response)>; fn handler(&mut self, receiver: Receiver) -> Self::Results { // Grab the websocket, we can safely lock it since we are the only ones using it // TODO: What happens if the websocket is closed by the receiving end? // probably just get an error here which is fine? 
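// Message flow: scanner threads push Http1Message values over the channel; every
// Result is forwarded to the client as a {"result": ...} text frame, Rx/Tx statistics
// are accumulated and pushed as {"stats": ...}, and Finished ends the loop.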
let mut websocket = self.websocket.lock().unwrap(); // Variables holding statistics and results let mut results: Vec<(Http1Request, Http1Response)> = Vec::with_capacity(16); let mut total_packets_received: u64 = 0; let mut total_rx_buffer_full: u64 = 0; let mut total_syn_packets_sent: u64 = 0; let mut total_packets_sent: u64 = 0; let mut total_tx_buffer_full: u64 = 0; let mut results_found = 0; let start_time = std::time::Instant::now(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { Http1Message::Result((request, response)) => { results_found += 1; let result = serde_json::json!({ "result": { "request": request, "response": response } }); println!("Sending result: {}", result); if websocket.send(Message::text(result.to_string())).is_err() { eprintln!("Got an error"); break; } results.push((request, response)); } Http1Message::RxStats { packet_count, full_read_buffer_count, } => { total_packets_received += packet_count as u64; total_rx_buffer_full += full_read_buffer_count as u64; } Http1Message::TxStats { syn_packet_count, packet_count, full_write_buffer_count, } => { total_syn_packets_sent += syn_packet_count as u64; total_packets_sent += packet_count as u64; total_tx_buffer_full += full_write_buffer_count as u64; } Http1Message::Finished => { break; } } let stats = serde_json::json!({ "stats": { "total_targets": self.total_targets, "total_remaining": self.total_targets - total_syn_packets_sent, "total_results": results_found, "total_packets_received": total_packets_received, "total_rx_buffer_full": total_rx_buffer_full, "total_syn_packets_sent": total_syn_packets_sent, "total_packets_sent": total_packets_sent, "total_tx_buffer_full": total_tx_buffer_full, } }); println!("Sending result: {}", stats); if websocket.send(Message::text(stats.to_string())).is_err() { eprintln!("Got an error"); break; } } let end_time = std::time::Instant::now(); let duration = (end_time - start_time).as_secs(); if websocket .send(Message::text( serde_json::json!({ "finished": { "duration": duration } }) .to_string(), )) .is_err() { eprintln!("Got an error"); } results } } ``` ### `services/mrpf_scanner_api/src/progress/mod.rs` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/src/progress/mod.rs` ```rust mod https; mod tcpsyn; mod tls; pub(crate) use https::Http1ProgressHandlerWebSocket; pub(crate) use tcpsyn::TcpSynProgressHandlerWebSocket; pub(crate) use tls::TlsProgressHandlerWebSocket; ``` ### `services/mrpf_scanner_api/src/progress/tcpsyn.rs` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/src/progress/tcpsyn.rs` ```rust use mrpf_engine::ProgressHandler; use mrpf_tcp_syn_scanner::TcpSynMessage; use std::collections::{HashMap, HashSet}; use std::net::{Ipv4Addr, TcpStream}; use std::sync::mpsc::Receiver; use std::sync::{Arc, Mutex}; use tungstenite::Message; use tungstenite::WebSocket; /// CLI ProgressHandler for the TCP SYN Scanner. /// /// The handler will ensure the scanner is stopped and the results are collected. A progress /// bar will be displayed in the CLI to show the progress of the scan. 
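///
/// In this WebSocket variant no progress bar is drawn; open ports and running
/// statistics are pushed to the client as JSON messages instead.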
#[derive(Debug)] pub struct TcpSynProgressHandlerWebSocket { total_targets: u64, websocket: Arc>>, } impl TcpSynProgressHandlerWebSocket { pub fn new( websocket: Arc>>, ip_address_count: u32, port_count: u16, ) -> Self { Self { total_targets: (ip_address_count * port_count as u32) as u64, websocket, } } } impl ProgressHandler for TcpSynProgressHandlerWebSocket { type Message = TcpSynMessage; type Results = HashMap>; fn handler(&mut self, receiver: Receiver) -> Self::Results { // Grab the websocket, we can safely lock it since we are the only ones using it // TODO: What happens if the websocket is closed by the receiving end? // probably just get an error here which is fine? let mut websocket = self.websocket.lock().unwrap(); // Variables holding statistics let mut results: HashMap> = HashMap::with_capacity(32); let mut total_packets_received: u64 = 0; let mut total_rx_buffer_full: u64 = 0; let mut total_syn_packets_sent: u64 = 0; let mut total_packets_sent: u64 = 0; let mut total_tx_buffer_full: u64 = 0; let mut open_ports_found = 0; let start_time = std::time::Instant::now(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { TcpSynMessage::Result((ip, port)) => { let ip_entry = results.entry(ip).or_default(); // Avoid duplicates if ip_entry.insert(port) { open_ports_found += 1; if websocket .send(Message::text( serde_json::json!({ "result": { "ip": ip, "port": port } }) .to_string(), )) .is_err() { eprintln!("Got an error"); break; } } } TcpSynMessage::RxStats { packet_count, full_read_buffer_count, } => { total_packets_received += packet_count as u64; total_rx_buffer_full += full_read_buffer_count as u64; } TcpSynMessage::TxStats { syn_packet_count, packet_count, full_write_buffer_count, } => { total_syn_packets_sent += syn_packet_count as u64; total_packets_sent += packet_count as u64; total_tx_buffer_full += full_write_buffer_count as u64; } TcpSynMessage::Finished => { break; } } if websocket .send(Message::text( serde_json::json!({ "stats": { "total_targets": self.total_targets, "total_remaining": self.total_targets - total_syn_packets_sent, "open_ports": open_ports_found, "total_packets_received": total_packets_received, "total_rx_buffer_full": total_rx_buffer_full, "total_syn_packets_sent": total_syn_packets_sent, "total_packets_sent": total_packets_sent, "total_tx_buffer_full": total_tx_buffer_full, } }) .to_string(), )) .is_err() { eprintln!("Got an error"); break; } } let end_time = std::time::Instant::now(); let duration = (end_time - start_time).as_secs(); if websocket .send(Message::text( serde_json::json!({ "finished": { "duration": duration } }) .to_string(), )) .is_err() { eprintln!("Got an error"); } results } } ``` ### `services/mrpf_scanner_api/src/progress/tls.rs` **Absolute path:** `/root/mrpf/services/mrpf_scanner_api/src/progress/tls.rs` ```rust use mrpf_engine::ProgressHandler; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use mrpf_tls_scanner::TlsMessage; use std::collections::{HashMap, HashSet}; use std::net::{Ipv4Addr, TcpStream}; use std::sync::mpsc::Receiver; use std::sync::{Arc, Mutex}; use tungstenite::Message; use tungstenite::WebSocket; /// CLI ProgressHandler for the TLS Scanner /// /// The handler will ensure the scanner is stopped and the results are collected. A progress /// bar will be displayed in the CLI to show the progress of the scan. 
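///
/// In this WebSocket variant no progress bar is drawn; discovered names and running
/// statistics are streamed to the client as JSON messages instead.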
#[derive(Debug)] pub struct TlsProgressHandlerWebSocket { total_targets: u64, websocket: Arc>>, } impl TlsProgressHandlerWebSocket { /// Initialize a new TlsProgress Handler for WebSocket pub fn new( websocket: Arc>>, targets: &[(Vec, Ipv4AddressRanges, PortRanges)], ) -> Self { let total_targets = targets .iter() .map(|(names, ips, ports)| names.len() as u64 * ips.len() as u64 * ports.len() as u64) .sum(); Self { total_targets, websocket, } } } impl ProgressHandler for TlsProgressHandlerWebSocket { type Message = TlsMessage; type Results = HashMap>; fn handler(&mut self, receiver: Receiver) -> Self::Results { // Grab the websocket, we can safely lock it since we are the only ones using it // TODO: What happens if the websocket is closed by the receiving end? // probably just get an error here which is fine? let mut websocket = self.websocket.lock().unwrap(); let mut results = HashMap::>::with_capacity(16); // Variables holding statistics let mut total_packets_received: u64 = 0; let mut total_rx_buffer_full: u64 = 0; let mut total_syn_packets_sent: u64 = 0; let mut total_packets_sent: u64 = 0; let mut total_tx_buffer_full: u64 = 0; let mut names_found = 0; let start_time = std::time::Instant::now(); loop { let message = match receiver.recv() { Ok(message) => message, Err(e) => { eprintln!("Failed to receive status update: {}", e); break; } }; match message { TlsMessage::Result((names, ip, port)) => { for name in names { let entry = results.entry(name.clone()).or_default(); let new_ip_port = entry.insert((ip, port)); if entry.len() == 1 { names_found += 1; } if new_ip_port && websocket .send(Message::text( serde_json::json!({ "result": { "name": name, "ip": ip.to_string(), "port": port } }) .to_string(), )) .is_err() { eprintln!("Got an error"); break; } } } TlsMessage::RxStats { packet_count, full_read_buffer_count, } => { total_packets_received += packet_count as u64; total_rx_buffer_full += full_read_buffer_count as u64; } TlsMessage::TxStats { syn_packet_count, packet_count, full_write_buffer_count, } => { total_syn_packets_sent += syn_packet_count as u64; total_packets_sent += packet_count as u64; total_tx_buffer_full += full_write_buffer_count as u64; } TlsMessage::Finished => { break; } } if websocket .send(Message::text( serde_json::json!({ "stats": { "total_targets": self.total_targets, "total_remaining": self.total_targets - total_syn_packets_sent, "names_found": names_found, "total_packets_received": total_packets_received, "total_rx_buffer_full": total_rx_buffer_full, "total_syn_packets_sent": total_syn_packets_sent, "total_packets_sent": total_packets_sent, "total_tx_buffer_full": total_tx_buffer_full, } }) .to_string(), )) .is_err() { eprintln!("Got an error"); break; } } let end_time = std::time::Instant::now(); let duration = (end_time - start_time).as_secs(); if websocket .send(Message::text( serde_json::json!({ "finished": { "duration": duration } }) .to_string(), )) .is_err() { eprintln!("Got an error"); } results } } ``` ### `services/mrpf_task_worker/src/handler.rs` **Absolute path:** `/root/mrpf/services/mrpf_task_worker/src/handler.rs` ```rust use mrpf_core::Result; use mrpf_core::tasks::ExecutionContext; use mrpf_core::tasks::models::{ Task, TaskContainerId, TaskManagerQueueMessage, TaskResult, TaskState, }; use tracing::{info, warn}; /// Main handler function for handling incoming Task Worker events /// /// # Arguments /// * `tasks` - Vec<(TaskContainerId, Task)> - A vector of tasks to be processed /// * `context` - &ExecutionContext - The worker context containing 
database and queue clients pub async fn handle_tasks( tasks: Vec<(TaskContainerId, Task)>, context: &ExecutionContext, ) -> Result> { let mut messages = Vec::new(); for (root_container_id, task) in tasks { let mut task = task; let task_id = *task.id(); info!(task_id = %task_id, task_type = %task.definition(), "Handling StartTask"); messages.push(process_task_result( &task_id, task.execute(context).await, root_container_id, )); } Ok(messages) } /// Helper to convert a task execution result into a TaskManagerQueueMessage /// /// It might be tempting to remove TaskResult completely and just return TaskManagerQueueMessage /// from our task templates, but this will tie those into the task manager implementation because /// we need the task ID and container id for the CompleteTask message. fn process_task_result( task_id: &TaskContainerId, result: Result, root_container_id: TaskContainerId, ) -> TaskManagerQueueMessage { match result { Ok(result) => { let (state, data, expected_task_count) = match result { TaskResult::CompletedWithData(state, data) => { info!(task_id = %task_id, "Task completed with data"); (state, Some(data), None) } TaskResult::Completed(state) => { info!(task_id = %task_id, "Task completed"); (state, None, None) } TaskResult::GeneratedTasks(expected_task_count) => ( TaskState::PendingTaskCreation, None, Some(expected_task_count), ), }; TaskManagerQueueMessage::CompleteTask { task_id: *task_id, state, root_container_id, expected_task_count, data, } } Err(e) => { warn!(error=%e, task_id = %task_id, "Failed to execute task"); TaskManagerQueueMessage::CompleteTask { task_id: *task_id, state: TaskState::Failed, root_container_id, expected_task_count: None, data: None, } } } } ``` ### `services/mrpf_task_worker/src/lib.rs` **Absolute path:** `/root/mrpf/services/mrpf_task_worker/src/lib.rs` ```rust mod handler; pub use handler::handle_tasks; ``` ### `shared/lambda_helpers/src/db_details.rs` **Absolute path:** `/root/mrpf/shared/lambda_helpers/src/db_details.rs` ```rust use lambda_runtime::Error; use std::env; use tracing::debug; /// Retrieve the database connection details by fetching the secret from AWS Secrets Manager. pub async fn get_db_connection_details( config: &aws_config::SdkConfig, ) -> Result<(String, u16, String, String, String), Error> { let proxy_endpoint = env::var("PROXY_ENDPOINT")?; debug!(proxy_endpoint = %proxy_endpoint, "Retrieved RDS Proxy endpoint from environment"); let secret_name = env::var("DB_SECRET")?; debug!(secret_name = %secret_name, "Retrieving RDS password from secretsmanager"); // Fetch the secret value let client = aws_sdk_secretsmanager::Client::new(config); debug!("Created Secrets Manager client"); let secret_value = client .get_secret_value() .secret_id(secret_name) .send() .await?; debug!("Successfully retrieved secret string"); let secret_string = secret_value .secret_string() .ok_or_else(|| Error::from("Secret is not a string"))?; let creds: serde_json::Value = serde_json::from_str(secret_string)?; let username = creds["username"] .as_str() .ok_or_else(|| Error::from("Missing username"))? .to_owned(); let password = creds["password"] .as_str() .ok_or_else(|| Error::from("Missing password"))? 
.to_owned(); let database_name: String = "postgres".to_string(); debug!(proxy_endpoint=%proxy_endpoint, db_name=%database_name, db_username=%username, "Retrieved database connection info"); Ok((proxy_endpoint, 5432, database_name, username, password)) } ``` ### `shared/lambda_helpers/src/lib.rs` **Absolute path:** `/root/mrpf/shared/lambda_helpers/src/lib.rs` ```rust mod db_details; mod path_params; mod query_params; mod tracing; pub use db_details::get_db_connection_details; pub use path_params::{get_path_string, get_path_uuid}; pub use query_params::{ get_multi_strings, get_multi_u16s, get_multi_uuids, get_opt_bool, get_opt_i64, get_opt_string, get_opt_u16, get_opt_uuid, }; pub use tracing::init_tracing; ``` ### `shared/lambda_helpers/src/path_params.rs` **Absolute path:** `/root/mrpf/shared/lambda_helpers/src/path_params.rs` ```rust use std::collections::HashMap; use std::hash::RandomState; use uuid::Uuid; /// Return a UUID for `key` from the PathMap (first occurrence) pub fn get_path_uuid( path_parameters: &HashMap, key: &str, ) -> Result { let id_str = path_parameters .get(key) .cloned() .ok_or_else(|| format!("{}: missing path parameter", key))?; let id = Uuid::parse_str(&id_str).map_err(|_| format!("{}: invalid uuid", key))?; Ok(id) } /// Return a String for `key` from the PathMap (first occurrence) pub fn get_path_string( path_parameters: &HashMap, key: &str, ) -> Result { let value = path_parameters .get(key) .cloned() .ok_or_else(|| format!("{}: missing path parameter", key))?; Ok(value) } ``` ### `shared/lambda_helpers/src/query_params.rs` **Absolute path:** `/root/mrpf/shared/lambda_helpers/src/query_params.rs` ```rust use query_map::QueryMap; use std::num::ParseIntError; use uuid::Uuid; /// Return an optional string for `key` from the QueryMap (first occurrence) pub fn get_opt_string(map: &QueryMap, key: &str) -> Option { map.first(key).map(|s| s.to_string()) } /// Return an optional bool for `key` from the QueryMap (first occurrence) pub fn get_opt_bool(map: &QueryMap, key: &str) -> Result, String> { match map.first(key) { None => Ok(None), Some(s) => s .parse::() .map(Some) .map_err(|e| format!("{}: invalid bool: {}", key, e)), } } /// Return an optional u16 for `key` from the QueryMap (first occurrence) pub fn get_opt_u16(map: &QueryMap, key: &str) -> Result, String> { match map.first(key) { None => Ok(None), Some(s) => s .parse::() .map(Some) .map_err(|e: ParseIntError| format!("{}: invalid u16: {}", key, e)), } } /// Return an optional i64 for `key` from the QueryMap (first occurrence) pub fn get_opt_i64(map: &QueryMap, key: &str) -> Result, String> { match map.first(key) { None => Ok(None), Some(s) => s .parse::() .map(Some) .map_err(|e: ParseIntError| format!("{}: invalid i64: {}", key, e)), } } /// Return an optional UUID for `key` from the QueryMap (first occurrence) pub fn get_opt_uuid(map: &QueryMap, key: &str) -> Result, String> { match map.first(key) { None => Ok(None), Some(s) => Uuid::parse_str(s) .map(Some) .map_err(|e| format!("{}: invalid uuid: {}", key, e)), } } /// Return all occurrences of `key` as Vec (multi-value params) pub fn get_multi_strings(map: &QueryMap, key: &str) -> Vec { map.all(key) .unwrap_or_default() .into_iter() .map(|s| s.to_string()) .collect() } /// Return all occurrences of `key` as Vec pub fn get_multi_u16s(map: &QueryMap, key: &str) -> Result, String> { let vals = map.all(key).unwrap_or_default(); let mut out = Vec::with_capacity(vals.len()); for s in vals { let n = s .parse::() .map_err(|e: ParseIntError| format!("{}: invalid u16: 
{}", key, e))?; out.push(n); } Ok(out) } /// Return all occurrences of `key` parsed as UUIDs pub fn get_multi_uuids(map: &QueryMap, key: &str) -> Result, String> { let vals = map.all(key).unwrap_or_default(); let mut out = Vec::with_capacity(vals.len()); for s in vals { let id = Uuid::parse_str(s).map_err(|e| format!("{}: invalid uuid: {}", key, e))?; out.push(id); } Ok(out) } ``` ### `shared/lambda_helpers/src/tracing.rs` **Absolute path:** `/root/mrpf/shared/lambda_helpers/src/tracing.rs` ```rust use tracing_subscriber::EnvFilter; use tracing_subscriber::filter::LevelFilter; /// Initialize tracing for logging in AWS Lambda environment. pub fn init_tracing() { // Determine log level from environment variable RUST_LOG. let env_filter = EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(); // Initialize tracing subscriber for logging tracing_subscriber::fmt() .json() .with_env_filter(env_filter) // this needs to be set to remove duplicated information in the log. .with_current_span(false) // this needs to be set to false, otherwise ANSI color codes will // show up in a confusing manner in CloudWatch logs. .with_ansi(false) // disabling time is handy because CloudWatch will add the ingestion time. .without_time() // remove the name of the function from every log entry .with_target(false) .init(); } ``` ### `shared/mrpf_cert_parser/src/error.rs` **Absolute path:** `/root/mrpf/shared/mrpf_cert_parser/src/error.rs` ```rust // For the Result alias, don't do: // ```type Result = std::result::Result;``` // But do: // ```type Result = std::result::Result;``` // // Explanation [here](https://users.rust-lang.org/t/dont-do-type-result-t-std-result-t-myerror/49421) // // A further improvement is to add an empty tuple to the result T, this again saves a few characters. // It might not be super obvious to people working with my code, but this is my preference so acceptable // for this private codebase. /// Result type alias for mrpf_core library /// /// - The result type defaults to an empty tuple /// - The error type is an explicit enum with conversion methods `crate::Error`. /// /// Error types: /// * `Database` - Database errors /// * `Other` - Other errors pub type Result = std::result::Result; #[derive(Debug)] pub enum Error { InvalidData, InvalidTag, InvalidLength, CommonNameNotFound, InvalidString, IncompleteData, } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) } } impl std::error::Error for Error {} ``` ### `shared/mrpf_cert_parser/src/fqdn.rs` **Absolute path:** `/root/mrpf/shared/mrpf_cert_parser/src/fqdn.rs` ```rust /// Cleanup the FQDN by removing: /// - leading ' ', '.', '*', or '?' /// - trailing ' ' and '.' /// /// Keeps internal dots and wildcard chars that are not at the start. /// Leaves case untouched (caller can lowercase if desired). pub fn clean_fqdn(mut fqdn: String) -> String { // Trim leading chars let lead = fqdn .chars() .take_while(|c| matches!(c, '\n' | '\t' | '\r' | ' ' | '.' 
| '*' | '?')) .count(); if lead > 0 { fqdn.drain(..lead); if fqdn.is_empty() { return fqdn; } } // Trim trailing chars let trail = fqdn .chars() .rev() .take_while(|c| matches!(c, '\n' | '\t' | '\r' | ' ' | '.')) .count(); if trail > 0 { let new_len = fqdn.len() - trail; fqdn.truncate(new_len); } fqdn } ``` ### `shared/mrpf_cert_parser/src/lib.rs` **Absolute path:** `/root/mrpf/shared/mrpf_cert_parser/src/lib.rs` ```rust mod error; mod fqdn; mod parser; pub mod tags; // Exposed for TLS record parsing validation and external ASN.1 tag usage pub mod tlv; // Expose TLV helpers for other crates (e.g., network engine) pub use error::{Error, Result}; pub use fqdn::clean_fqdn; pub use parser::{extract_cn_and_san, extract_leaf_certificate}; pub use tlv::{parse_tlv, skip_tlv}; ``` ### `shared/mrpf_cert_parser/src/parser.rs` **Absolute path:** `/root/mrpf/shared/mrpf_cert_parser/src/parser.rs` ```rust use super::tlv::{parse_tlv, skip_tlv}; use crate::tags::*; use crate::{Error, Result}; use tracing::error; // OID constants (DER encoded value only, excluding tag/length) const OID_CN: &[u8] = b"\x55\x04\x03"; // 2.5.4.3 (CN) const OID_SAN: &[u8] = b"\x55\x1D\x11"; // 2.5.29.17 (SAN) /// Extract common name from given certificate subject fn extract_cn(subject: &[u8]) -> Result { let mut data = subject; while !data.is_empty() { let (tag, _, value, rest) = parse_tlv(data)?; if tag != TAG_SET { return Err(Error::InvalidTag); } let (seq_tag, _, seq_value, _seq_rest) = parse_tlv(value)?; if seq_tag != TAG_SEQUENCE { // SEQUENCE (AttributeTypeAndValue) return Err(Error::InvalidTag); } let (oid_tag, _, oid_value, oid_rest) = parse_tlv(seq_value)?; if oid_tag != TAG_OBJECT_IDENTIFIER { data = rest; // Advance past entire SET continue; } if oid_value == OID_CN { // 2.5.4.3 (CN) let (_str_tag, _, str_value, _) = parse_tlv(oid_rest)?; return std::str::from_utf8(str_value) .map(|s| s.to_string()) .map_err(|_| Error::InvalidString); } data = rest; // Advance past entire SET for non-CN } Err(Error::CommonNameNotFound) } fn extract_san(extensions: &[u8]) -> Result> { let mut data = extensions; let mut san_names = Vec::new(); while !data.is_empty() { let (tag, _, value, rest) = parse_tlv(data)?; if tag != TAG_SEQUENCE { // Extension SEQUENCE error!(" InvalidTag extension SEQUENCE: {:02X}", tag); return Err(Error::InvalidTag); } let (oid_tag, _, oid_value, mut inner_rest) = parse_tlv(value)?; // debug!(" OID Tag: {:02X}, OID: {:02X?}", oid_tag, oid_value); if oid_tag != TAG_OBJECT_IDENTIFIER { // OBJECT IDENTIFIER data = rest; continue; } if oid_value == OID_SAN { // debug!(" Found SAN OID"); loop { let (next_tag, next_length, next_value, next_rest) = parse_tlv(inner_rest)?; // debug!( // " Next Tag: {:02X}, Length: {}, Value start: {:02X?}", // next_tag, // next_length, // &next_value[..std::cmp::min(next_value.len(), 8)] // ); if next_tag == 0x01 { // BOOLEAN (criticality) // debug!(" Found criticality:"); if next_length != 1 || next_value.len() != 1 { error!(" Invalid criticality BOOLEAN length: {}", next_length); return Err(Error::InvalidData); } // debug!(" Criticality: {:02X}", next_value[0]); inner_rest = next_rest; // Advance past BOOLEAN (3 bytes total) continue; // Restart loop to parse next TLV (OCTET STRING) } else if next_tag == TAG_OCTET_STRING { let (seq_tag, _, seq_content, _) = parse_tlv(next_value)?; // debug!( // " Octet Tag: {:02X}, Sequence Tag: {:02X}", // next_tag, seq_tag // ); if seq_tag != TAG_SEQUENCE { // SEQUENCE of GeneralNames error!(" InvalidTag GeneralNames SEQUENCE: {:02X}", seq_tag); 
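// Per RFC 5280 §4.2.1.6 the SAN extnValue OCTET STRING wraps a GeneralNames
// SEQUENCE; any other tag at this point means the extension is malformed.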
return Err(Error::InvalidTag); } san_names.extend(extract_dns_names(seq_content)?); return Ok(san_names); } else { error!( " InvalidTag expecting BOOLEAN or OCTET STRING: {:02X}", next_tag ); return Err(Error::InvalidTag); } } } data = rest; } Ok(san_names) } // Helper function to extract dNSNames from GeneralNames SEQUENCE fn extract_dns_names(seq_content: &[u8]) -> Result> { let mut san_data = seq_content; let mut san_names = Vec::new(); while !san_data.is_empty() { let (san_tag, _, san_value, rest_san) = parse_tlv(san_data)?; // debug!( // " SAN Tag: {:02X}, Value start: {:02X?}", // san_tag, // &san_value[..std::cmp::min(san_value.len(), 8)] // ); // Other tags I've seen, but not handling yet: // 0x81: rfc822Name (email) // 0x82: IA5String (dNSName) // 0x87: IP address // 0x86: URI // The full list of GeneralName types in SAN (per RFC 5280, Section 4.2.1.6) is: // otherName (0xA0, with inner OID and value) // rfc822Name (0x81 or sometimes 0x82) This is the email address of the issuer // dNSName (0x82) // x400Address (0xA3, rare, legacy X.400 protocol) // directoryName (0xA4 or 0xA0 depending on tools, typically a SEQUENCE) // ediPartyName (0xA5, extremely rare) // uniformResourceIdentifier (0x86) // iPAddress (0x87) // registeredID (0x88 or 0x83) if san_tag == TAG_DNS_NAME { // IA5String (dNSName) let name = std::str::from_utf8(san_value) .map(|s| s.to_string()) .map_err(|_| Error::InvalidString)?; san_names.push(name); } san_data = rest_san; } Ok(san_names) } /// Extract the Common name and Subject Alternative Names from a certificate pub fn extract_cn_and_san(cert: Vec) -> Result<(Option, Vec)> { // debug!("{:02X?}...", &cert); let (tag, _, value, _) = parse_tlv(&cert)?; // debug!("{:02X}: {:02X?}...", tag, &value[..4]); if tag != TAG_SEQUENCE { // Certificate SEQUENCE error!( "InvalidTag cert SEQUENCE: {:02X}: {:02X?}...", tag, &value[..4] ); return Err(Error::InvalidTag); } let (tbs_tag, _, mut data, _) = parse_tlv(value)?; // debug!("{:02X}: {:02X?}...", tbs_tag, &data[..4]); if tbs_tag != TAG_SEQUENCE { // tbsCertificate SEQUENCE error!( "InvalidTag tbsCertificate SEQUENCE: {:02X}: {:02X?}...", tbs_tag, &data[..4] ); return Err(Error::InvalidTag); } if !data.is_empty() && data[0] == 0xA0 { // Skip version data = skip_tlv(data)?; } // serialNumber, signature, issuer, validity data = skip_tlv(data)?; // serialNumber data = skip_tlv(data)?; // signature data = skip_tlv(data)?; // issuer data = skip_tlv(data)?; // validity // Parse subject let (subject_tag, _, subject_value, rest) = parse_tlv(data)?; if subject_tag != 0x30 { return Err(Error::InvalidTag); } // Extract CN let cn = match extract_cn(subject_value) { Ok(cn) => Some(cn), Err(Error::CommonNameNotFound) | Err(Error::InvalidString) => None, Err(e) => { let cn_lossy = String::from_utf8_lossy(subject_value).to_string(); error!("Failed to extract cn from {:?}: {:?}", cn_lossy, e); None } }; data = skip_tlv(rest)?; // Skip subjectPublicKeyInfo // Iterate through remaining optional fields to find extensions // and extract SAN names let mut san_names = Vec::new(); while !data.is_empty() { let (tag, _, value, rest) = parse_tlv(data)?; if tag == 0xA3 { // Extensions let (ext_tag, _, ext_value, _) = parse_tlv(value)?; // debug!("{:02X?}...", &ext_value[..4]); if ext_tag != TAG_SEQUENCE { return Err(Error::InvalidTag); } // debug!("Extracting SAN names"); san_names = extract_san(ext_value)?; break; } else if tag == 0xA1 || tag == 0xA2 { // debug!("Found issuerUniqueID or subjectUniqueID"); // Skip issuerUniqueID or subjectUniqueID 
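// Context tags [1] (0xA1) and [2] (0xA2) are the optional issuerUniqueID and
// subjectUniqueID fields of tbsCertificate; [3] (0xA3, handled above) carries the
// Extensions.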
data = rest; } else { break; } } Ok((cn, san_names)) } /// Extracts the first certificate (sender's certificate) from a TLS 1.2 certificate chain. /// /// The input `certificate_chain` is expected to be the raw bytes of the certificates list /// from a TLS 1.2 Certificate handshake message, as returned by `extract_tls12_certificate_chain`. /// This function assumes the chain is well-formed and returns the first certificate's raw DER bytes. /// /// # Errors /// Returns `Err(TlsExtractError::Incomplete)` if the input is too short to contain a certificate. /// Returns `Err(TlsExtractError::Invalid)` if the length field is malformed or exceeds the input. /// /// # Panics /// This function does not panic under normal circumstances, assuming `BytesMut` operations succeed. pub fn extract_leaf_certificate(mut certificate_chain: Vec) -> Result> { // Need at least 3 bytes for the certificate length field if certificate_chain.len() < 3 { return Err(Error::IncompleteData); } // Parse 3-byte big-endian length (certificate length) let cert_length = ((certificate_chain[0] as usize) << 16) | ((certificate_chain[1] as usize) << 8) | (certificate_chain[2] as usize); // Bounds check: ensure the buffer contains the full certificate bytes let end = 3 + cert_length; if certificate_chain.len() < end { return Err(Error::IncompleteData); } // Remove trailing certificates and length prefix, leaving only the leaf certificate certificate_chain.drain(end..); certificate_chain.drain(0..3); Ok(certificate_chain) } ``` ### `shared/mrpf_cert_parser/src/tags.rs` **Absolute path:** `/root/mrpf/shared/mrpf_cert_parser/src/tags.rs` ```rust // Generic ASN.1 Tags used in MRPF certificate parsing pub const TAG_SET: u8 = 0x31; pub const TAG_SEQUENCE: u8 = 0x30; pub const TAG_OBJECT_IDENTIFIER: u8 = 0x06; // pub(crate) const TAG_UTF8_STRING: u8 = 0x0C; pub(crate) const TAG_OCTET_STRING: u8 = 0x04; /// Data specific tags // pub(crate) const TAG_RFC822_NAME: u8 = 0x81; // Email address pub const TAG_DNS_NAME: u8 = 0x82; // IA5String // pub(crate) const TAG_URI: u8 = 0x86; // pub(crate) const TAG_IP_ADDRESS: u8 = 0x87; ``` ### `shared/mrpf_cert_parser/src/tlv.rs` **Absolute path:** `/root/mrpf/shared/mrpf_cert_parser/src/tlv.rs` ```rust use crate::{Error, Result}; use tracing::error; /// Parse a Tag Length Value (TLV) structure from the beginning of the given data slice. /// /// # Arguments /// * `data` - A byte slice containing the TLV encoded data. /// /// Returns a tuple containing: /// - tag: u8 - The tag byte. /// - length: usize - The length of the value. /// - value: &[u8] - A slice containing the value bytes. 
/// - remaining: &[u8] - A slice containing the remaining bytes after the TLV #[inline(always)] pub fn parse_tlv(data: &[u8]) -> Result<(u8, usize, &[u8], &[u8])> { if data.len() < 2 { error!("InvalidData: data.len()={}", data.len()); return Err(Error::InvalidData); } let tag = data[0]; let length_byte = data[1]; let mut offset = 2; // After tag and length byte let mut length = length_byte as usize; // Handle DER long-form length if applicable if length_byte >= 0x80 { let len_bytes = (length_byte & 0x7F) as usize; if offset + len_bytes > data.len() { return Err(Error::InvalidLength); } length = 0; for &byte in &data[offset..offset + len_bytes] { length = (length << 8) + byte as usize; } offset += len_bytes; } if offset + length > data.len() { error!( "InvalidLength: offset={}, length={}, data.len()={}", offset, length, data.len() ); error!("Data: {:?}", &data[..std::cmp::min(data.len(), 16)]); return Err(Error::InvalidLength); } let value = &data[offset..offset + length]; let remaining = &data[offset + length..]; Ok((tag, length, value, remaining)) } /// Skip over a TLV structure in the given data slice. #[inline(always)] pub fn skip_tlv(data: &[u8]) -> Result<&[u8], Error> { let (_, _, _, remaining) = parse_tlv(data)?; Ok(remaining) } ``` ### `shared/mrpf_core/sql/initial.md` **Absolute path:** `/root/mrpf/shared/mrpf_core/sql/initial.md` ```markdown # Initial configuration for SQL We need to setup the postgresql users and their passwords here. 1. Deploy the AWS Secretsmanager entry and get it to generate the password automatically 2. Get the generated passwords from secretsmanager 3. run SQL script similar to the following for creating the required users and passwords. I typically keep this script locally in my repository called initial_config.sql and have it in my .gitignore so it does not get committed with passwords in it. ```sql -- Added explicit connection string here in the file to easily logon through cloudshell. -- psql "host=.eu-west-1.rds.amazonaws.com port=5432 dbname=postgres user=postgres password=" -- User for our Task Manager lambda DO $$ BEGIN IF NOT EXISTS ( SELECT FROM pg_catalog.pg_roles WHERE rolname = 'task_manager' ) THEN CREATE USER task_manager WITH PASSWORD ''; -- Ensure to fill in the password between quotes here END IF; END $$; -- User for our Task Scheduler lambda DO $$ BEGIN IF NOT EXISTS ( SELECT FROM pg_catalog.pg_roles WHERE rolname = 'task_scheduler' ) THEN CREATE USER task_scheduler WITH PASSWORD ''; -- Ensure to fill in the password between quotes here END IF; END $$; -- User for our Internal Worker lambda DO $$ BEGIN IF NOT EXISTS ( SELECT FROM pg_catalog.pg_roles WHERE rolname = 'internal_worker' ) THEN CREATE USER internal_worker WITH PASSWORD ''; -- Ensure to fill in the password between quotes here END IF; END $$; -- User for our API lambda DO $$ BEGIN IF NOT EXISTS ( SELECT FROM pg_catalog.pg_roles WHERE rolname = 'api' ) THEN CREATE USER api WITH PASSWORD ''; -- Ensure to fill in the password between quotes here END IF; END $$; ``` ``` ### `shared/mrpf_core/src/db/client.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/client.rs` ```rust use crate::Result; use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode}; use std::time::Duration; pub struct DbClient { pub(crate) pool: sqlx::PgPool, } impl DbClient { /// Initialize a new database client. 
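///
/// Illustrative call (placeholder values):
/// `DbClient::new("proxy-endpoint", 5432, "postgres", "task_manager", "<password>", "mrpf_api", "task_manager").await`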
/// /// # Arguments /// * `host` - The database host /// * `port` - The database port /// * `database_name` - The name of the database /// * `username` - The database username /// * `password` - The database password /// * `search_path` - The PostgreSQL search_path to set for the session. task_manager pub async fn new( host: &str, port: u16, database_name: &str, username: &str, password: &str, app_name: &str, search_path: &str, ) -> Result { // Compose the connection options let connect_options = PgConnectOptions::new() .host(host) .port(port) .database(database_name) .username(username) .password(password) .application_name(app_name) .ssl_mode(PgSslMode::Require); // Create owned Strings for values that will be moved into the 'after_connect' closure. // This avoids borrowing local &str references and satisfies the 'static requirement. let search_path_for_conn = search_path.to_string(); let pool = PgPoolOptions::new() .max_connections(100) // Align this with RDS Proxy settings .min_connections(5) .acquire_timeout(Duration::from_secs(30)) .idle_timeout(Duration::from_secs(600)) .max_lifetime(Duration::from_secs(1800)) .after_connect(move |conn, _| { let search_path = search_path_for_conn.clone(); Box::pin(async move { sqlx::query("SELECT set_config('search_path', $1, false)") .bind(search_path.as_str()) .execute(&mut *conn) .await?; Ok(()) }) }) .connect_with(connect_options) .await?; Ok(Self { pool }) } pub fn pool(&self) -> &sqlx::PgPool { &self.pool } } ``` ### `shared/mrpf_core/src/db/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/mod.rs` ```rust mod client; mod recon; #[cfg(feature = "task_shared")] mod tasks; pub use client::DbClient; ``` ### `shared/mrpf_core/src/db/recon/apn_device.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/apn_device.rs` ```rust use crate::DbClient; use crate::recon::{ApnDevice, ApnDeviceEnv, ApnDeviceKind}; use crate::{Error, Result}; impl DbClient { /// Fetch a single APN device by id. pub async fn get_apn_device(&self, id: &str) -> Result { let device: ApnDevice = sqlx::query_as( r#"SELECT id, name, kind, env, created_at FROM recon.apn_devices WHERE id = $1"#, ) .bind(id) .fetch_one(self.pool()) .await?; Ok(device) } /// List APN devices, optionally filtering by ids, kind, env. /// If `device_ids` is empty all devices are considered. /// Optional filters: if None they are ignored. pub async fn list_apn_devices( &self, device_ids: Vec, kind: Option, env: Option, ) -> Result> { let devices: Vec = sqlx::query_as( r#" SELECT id, name, kind, env, created_at FROM recon.apn_devices WHERE (cardinality($1::text[]) = 0 OR id = ANY($1)) AND ($2::apn_device_kind IS NULL OR kind = $2) AND ($3::apn_device_env IS NULL OR env = $3) ORDER BY created_at DESC "#, ) .bind(&device_ids) .bind(kind) .bind(env) .fetch_all(self.pool()) .await?; Ok(devices) } /// Upsert a single APN device. pub async fn put_apn_device(&self, device: &ApnDevice) -> Result { sqlx::query( r#" INSERT INTO recon.apn_devices (id, name, kind, env, created_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET name = COALESCE(EXCLUDED.name, recon.apn_devices.name), kind = COALESCE(EXCLUDED.kind, recon.apn_devices.kind), env = COALESCE(EXCLUDED.env, recon.apn_devices.env) "#, ) .bind(&device.id) .bind(&device.name) .bind(device.kind) .bind(device.env) .bind(device.created_at) .execute(self.pool()) .await?; Ok(()) } /// Bulk upsert APN devices. 
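///
/// The batch is serialized into a single JSONB array and expanded server-side with
/// `jsonb_to_recordset`, so the whole upsert is one statement and one round trip.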
pub async fn put_apn_devices(&self, devices: Vec) -> Result { if devices.is_empty() { return Err(Error::Other("No apn_devices provided to update".into())); } use serde_json::{Value as JsonValue, json}; let rows: Vec = devices .into_iter() .map(|d| { json!({ "id": d.id, "name": d.name, "kind": d.kind, "env": d.env, "created_at": d.created_at, }) }) .collect(); sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id text, name text, kind apn_device_kind, env apn_device_env, created_at timestamptz ) ) INSERT INTO recon.apn_devices (id, name, kind, env, created_at) SELECT id, COALESCE(name, ''), kind, env, COALESCE(created_at, now()) FROM input ON CONFLICT (id) DO UPDATE SET name = COALESCE(EXCLUDED.name, recon.apn_devices.name), kind = COALESCE(EXCLUDED.kind, recon.apn_devices.kind), env = COALESCE(EXCLUDED.env, recon.apn_devices.env) "#, ) .bind(JsonValue::Array(rows)) .execute(self.pool()) .await?; Ok(()) } /// Delete a single APN device. pub async fn delete_apn_device(&self, id: &str) -> Result { sqlx::query("DELETE FROM recon.apn_devices WHERE id = $1") .bind(id) .execute(self.pool()) .await?; Ok(()) } /// Delete multiple APN devices by id. pub async fn delete_apn_devices(&self, ids: Vec) -> Result { if ids.is_empty() { return Err(Error::Other("No apn_device ids provided to delete".into())); } sqlx::query("DELETE FROM recon.apn_devices WHERE id = ANY($1)") .bind(&ids) .execute(self.pool()) .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/recon/domain.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/domain.rs` ```rust use crate::DbClient; use crate::recon::{Domain, SoaRecord}; use crate::{Error, Result}; use chrono::{DateTime, Utc}; use serde_json::{Value as JsonValue, json}; use sqlx::Row; use std::collections::{HashMap, HashSet}; impl DbClient { /// Get a single domain by its FQDN. pub async fn get_domain(&self, fqdn: &str) -> Result { let row: Domain = sqlx::query_as( r#" SELECT f.id as id, f.target_id, f.is_wildcard, COALESCE(c.fqdns, ARRAY[]::text[]) AS cnames, -- Build enriched ip_addresses as JSONB with port metadata COALESCE( ( SELECT jsonb_agg( jsonb_build_object( 'ip', ip_txt, 'target_id', f.target_id, 'tcp_ports', COALESCE(ipinfo.tcp_ports, '{}'::int[]), 'udp_ports', COALESCE(ipinfo.udp_ports, '{}'::int[]) ) ) FROM UNNEST(COALESCE(a.ip_addresses, '{}'::text[])) AS ip_txt LEFT JOIN recon.ip_addresses ipinfo ON ipinfo.ip = ip_txt ), '[]'::jsonb ) AS ip_addresses FROM recon.fqdns f LEFT JOIN recon.cname_record_sets c ON c.fqdn_id = f.id AND c.is_current = true LEFT JOIN recon.a_record_sets a ON a.fqdn_id = f.id AND a.is_current = true WHERE f.id = $1 "#, ) .bind(fqdn) .fetch_one(self.pool()) .await?; Ok(row) } /// List domains for a target with optional filters. /// /// The returned cnames and ip_addresses are the latest known values. /// /// # Arguments /// /// * `target_id` - The ID of the target to list domains for. /// * `is_active` - Optional filter to include only active or inactive domains. /// * `is_wildcard` - Optional filter to include only wildcard or non-wildcard domains. 
pub async fn list_domains( &self, target_id: String, is_active: Option, is_wildcard: Option, is_dangling: Option, ) -> Result> { let domains = sqlx::query_as( r#" SELECT f.id as id, f.target_id, f.is_wildcard, COALESCE(c.fqdns, ARRAY[]::text[]) AS cnames, COALESCE( ( SELECT jsonb_agg( jsonb_build_object( 'ip', ip_txt, 'target_id', f.target_id, 'tcp_ports', COALESCE(ipinfo.tcp_ports, '{}'::int[]), 'udp_ports', COALESCE(ipinfo.udp_ports, '{}'::int[]) ) ) FROM UNNEST(COALESCE(a.ip_addresses, '{}'::text[])) AS ip_txt LEFT JOIN recon.ip_addresses ipinfo ON ipinfo.ip = ip_txt ), '[]'::jsonb ) AS ip_addresses FROM recon.fqdns f LEFT JOIN recon.cname_record_sets c ON c.fqdn_id = f.id AND c.is_current = true LEFT JOIN recon.a_record_sets a ON a.fqdn_id = f.id AND a.is_current = true WHERE f.target_id = $1 AND ($2::bool IS NULL OR f.is_active = $2) AND ($3::bool IS NULL OR f.is_wildcard = $3) AND ($4::bool IS NULL OR f.is_dangling = $4) "#, ) .bind(&target_id) .bind(is_active) .bind(is_wildcard) .bind(is_dangling) .fetch_all(self.pool()) .await? .into_iter() .collect(); Ok(domains) } /// List fqdns for a target with optional filters. /// /// # Arguments /// /// * `target_id` - The ID of the target to list domains for. /// * `is_active` - Optional filter to include only active or inactive domains. /// * `is_wildcard` - Optional filter to include only wildcard or non-wildcard domains. pub async fn list_fqdns( &self, target_id: String, is_active: Option, is_wildcard: Option, is_dangling: Option, ) -> Result> { // Simple filtered select; options are applied only when provided. let fqdns = sqlx::query_scalar( r#" SELECT id FROM recon.fqdns WHERE target_id = $1 AND ($2::bool IS NULL OR is_active = $2) AND ($3::bool IS NULL OR is_wildcard = $3) AND ($4::bool IS NULL OR is_dangling = $4) "#, ) .bind(&target_id) .bind(is_active) .bind(is_wildcard) .bind(is_dangling) .fetch_all(self.pool()) .await? .into_iter() .collect(); Ok(fqdns) } /// List all current SOA records for domains under a specific target. /// /// The returned set is deduplicated by owner_name. pub async fn list_soa_records(&self, target_id: &str) -> Result> { let soa_records = sqlx::query_as::<_, SoaRecord>( r#" SELECT s.owner_name, s.primary_ns_id AS primary_ns, s.hostmaster_id AS hostmaster FROM recon.soa_records s JOIN recon.fqdns fq ON fq.id = s.owner_name WHERE s.is_current = true AND fq.target_id = $1 "#, ) .bind(target_id) .fetch_all(self.pool()) .await? .into_iter() .collect(); Ok(soa_records) } /// Set the full SOA structure for all domains under a single target /// /// The given soa_records will be treated as the current authoritative (full truth) for all domains related /// to the target. /// /// - Existing SOA records for domains under this target that are not included /// in the provided soa_records will be marked as non-current. /// - New SOA records will be inserted as current. /// - Existing SOA records that match the provided records will be left unchanged. pub async fn set_target_soa_records( &self, target_id: &str, soa_records: HashSet, ) -> Result { if soa_records.is_empty() { return Err(Error::Other( "No SOA records provided to set for target".into(), )); } let mut input_rows: Vec = Vec::with_capacity(soa_records.len()); for rec in soa_records { input_rows.push(json!({ "owner_name": rec.owner_name, "primary_ns": rec.primary_ns, "hostmaster": rec.hostmaster })); } let json_array = JsonValue::Array(input_rows); // Single SQL statement performing ensure + retire + insert. 
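        // Illustrative shape of the JSONB bound as $2 (hypothetical values):
        //   [{"owner_name": "example.com", "primary_ns": "ns1.example.com",
        //     "hostmaster": "hostmaster.example.com"}]
        // jsonb_to_recordset($2::jsonb) expands that array into rows; the CTEs then ensure the
        // referenced FQDNs exist, retire current SOA rows that are missing or changed for this
        // target, and insert rows for new or changed SOAs.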
// We ignore returned counts; could be logged if needed. sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($2::jsonb) AS t( owner_name text, primary_ns text, hostmaster text ) ), fqdns_to_ensure AS ( -- union of all fqdn strings we must ensure exist SELECT owner_name AS fqdn FROM input UNION SELECT primary_ns FROM input UNION SELECT hostmaster FROM input ), ensured_fqdns AS ( INSERT INTO recon.fqdns (id, target_id, path, is_wildcard, is_active, is_dangling) SELECT f.fqdn, $1, util.fqdn_to_ltree(f.fqdn), NULL, true, false FROM fqdns_to_ensure f ON CONFLICT (id) DO NOTHING RETURNING id ), current_soas AS ( SELECT s.id, s.owner_name, s.primary_ns_id, s.hostmaster_id FROM recon.soa_records s JOIN recon.fqdns fq ON fq.id = s.owner_name WHERE s.is_current = true AND fq.target_id = $1 ), retire AS ( UPDATE recon.soa_records s SET is_current = false WHERE s.id IN ( SELECT s.id FROM current_soas s LEFT JOIN input i ON i.owner_name = s.owner_name WHERE i.owner_name IS NULL OR EXISTS ( SELECT 1 FROM input i2 WHERE i2.owner_name = s.owner_name AND (i2.primary_ns <> s.primary_ns_id OR i2.hostmaster <> s.hostmaster_id) ) ) RETURNING owner_name ), ins AS ( INSERT INTO recon.soa_records (owner_name, primary_ns_id, hostmaster_id) SELECT i.owner_name, i.primary_ns, i.hostmaster FROM input i LEFT JOIN current_soas s ON s.owner_name = i.owner_name AND s.primary_ns_id = i.primary_ns AND s.hostmaster_id = i.hostmaster WHERE s.id IS NULL -- brand new OR changed (join fails when differing) RETURNING owner_name ) SELECT (SELECT count(*) FROM retire) AS retired, (SELECT count(*) FROM ins) AS inserted; "#, ) .bind(target_id) .bind(json_array) .execute(self.pool()) .await?; Ok(()) } /// Create or update multiple domains. /// /// This function takes a vector of `Domain` instances and creates or updates them in the database. /// /// NOTE: This operation WILL add cnames and a records but it WILL NOT update ip_addresses tcp or udp ports. /// We have to consider if we want to do this eventually but for now I think we're better off leaving /// that to a dedicated put_ipv4_addresses function. Note that our get_domain and list_domains function WILL /// return the full Ipv4Address objects with tcp/udp ports. /// /// - If a domain already exists, its normal attributes will be overwritten. /// - CNAME and IP history are updated accordingly. /// - Zones are created as needed. /// - tcp/udp ports in ip_addresses are NOT updated. pub async fn put_domains(&self, domains: HashSet) -> Result { if domains.is_empty() { return Err(Error::Other("No domains provided to update".into())); } // Prepare JSON arrays for bulk set-based operations. let mut fqdn_rows: Vec = Vec::with_capacity(domains.len()); let mut cname_rows: Vec = Vec::new(); let mut ip_rows: Vec = Vec::new(); // Use a single timestamp for all new cname/ip_address rows in this batch to allow // historical correlation. let created_at = Utc::now(); let created_at_str = created_at.to_rfc3339(); for domain in domains.into_iter() { fqdn_rows.push(json!({ "id": domain.fqdn, "target_id": domain.target_id, "is_wildcard": domain.is_wildcard, "is_active": domain.is_active(), "is_dangling": domain.is_dangling() // technologies: domain.technologies, })); // NOTE: Even if either CNAME or A records are empty, we still need to upsert // them to ensure we've got that linked to the current state. // a_records_id mapping will be handled during upsert // cname fqdns are stored as-provided. Entries are order-sensitive and // duplicates could technically occur. 
The DNS resolver should // ensure no unlimited recursion is stored in the chain. cname_rows.push(json!({ "fqdn_id": domain.fqdn, "fqdns": domain.cnames, "created_at": created_at_str })); // Normalize ip addresses to sorted Vec for consistent storage (store raw IP strings only) let mut ip_addresses: Vec = domain .ip_addresses .into_iter() .map(|addr| addr.ip.to_string()) .collect(); ip_addresses.sort_unstable(); ip_rows.push(json!({ "fqdn_id": domain.fqdn, "ip_addresses": ip_addresses, "created_at": created_at_str })); } let mut tx = self.pool().begin().await?; // Bulk upsert FQDN rows sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id text, target_id text, is_wildcard bool, is_active bool, is_dangling bool ) ) INSERT INTO recon.fqdns (id, target_id, path, is_wildcard, is_active, is_dangling) SELECT id, target_id, util.fqdn_to_ltree(id), is_wildcard, is_active, is_dangling FROM input ON CONFLICT (id) DO UPDATE SET is_wildcard = COALESCE(EXCLUDED.is_wildcard, recon.fqdns.is_wildcard), is_active = COALESCE(EXCLUDED.is_active, recon.fqdns.is_active), is_dangling = COALESCE(EXCLUDED.is_dangling, recon.fqdns.is_dangling) "#, ) .bind(JsonValue::Array(fqdn_rows)) .execute(&mut *tx) .await?; // ------------------------------------------------------------------ // Bulk A (address) record handling. // Insert a new current row ONLY if the (order-insensitive) set of IPs // differs from the existing current row OR if no current row exists. // // Existing differing // current rows are retired (is_current=false) before inserting. // Finally, return mapping fqdn_id -> current a_records.id for all // FQDNs provided in this batch (even if unchanged) so we can link // CNAME rows to the correct a_records_id. // ------------------------------------------------------------------ let mut a_records_map: HashMap = HashMap::new(); let a_rows = sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( fqdn_id text, ip_addresses text[], created_at timestamptz ) ), ensured AS ( -- ensure every fqdn appears even if no IPs were provided (already true) but we may still need to create an empty row SELECT fqdn_id, ip_addresses, created_at FROM input ), retire AS ( UPDATE recon.a_record_sets ar SET is_current = false FROM ensured i WHERE ar.fqdn_id = i.fqdn_id AND ar.is_current = true AND ( (SELECT array_agg(x ORDER BY x) FROM unnest(ar.ip_addresses) x) IS DISTINCT FROM (SELECT array_agg(x ORDER BY x) FROM unnest(i.ip_addresses) x) ) RETURNING ar.fqdn_id ), unchanged AS ( SELECT i.fqdn_id, ar.id FROM ensured i JOIN recon.a_record_sets ar ON ar.fqdn_id = i.fqdn_id AND ar.is_current = true WHERE i.fqdn_id NOT IN (SELECT fqdn_id FROM retire) ), ins AS ( INSERT INTO recon.a_record_sets (fqdn_id, ip_addresses, created_at) SELECT i.fqdn_id, i.ip_addresses, i.created_at FROM ensured i LEFT JOIN recon.a_record_sets cur ON cur.fqdn_id = i.fqdn_id AND cur.is_current = true WHERE cur.fqdn_id IS NULL OR i.fqdn_id IN (SELECT fqdn_id FROM retire) RETURNING id, fqdn_id ) SELECT fqdn_id, id FROM unchanged UNION ALL SELECT fqdn_id, id FROM ins; "#, ) .bind(JsonValue::Array(ip_rows)) .fetch_all(&mut *tx) .await?; for r in a_rows.into_iter() { let fqdn_id: String = r.get("fqdn_id"); let id: i64 = r.get("id"); a_records_map.insert(fqdn_id, id); } // ------------------------------------------------------------------ // Bulk CNAME handling (single current row per fqdn). Only insert a new // row when ordered array differs from current. 
We attach the matching // a_records_id (if any) from the map above. // ------------------------------------------------------------------ if !cname_rows.is_empty() { let mut cname_rows_with_a: Vec = Vec::with_capacity(cname_rows.len()); for row in cname_rows.into_iter() { if let Some(fqdn_id) = row.get("fqdn_id").and_then(|v| v.as_str()) { let a_id = a_records_map.get(fqdn_id).copied(); cname_rows_with_a.push(json!({ "fqdn_id": fqdn_id, "fqdns": row.get("fqdns").cloned().unwrap_or(JsonValue::Array(vec![])), "a_records_id": a_id, "created_at": created_at_str })); } } if !cname_rows_with_a.is_empty() { sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( fqdn_id text, fqdns text[], a_records_id bigint, created_at timestamptz ) ), retire AS ( UPDATE recon.cname_record_sets c SET is_current = false FROM input i WHERE c.fqdn_id = i.fqdn_id AND c.is_current = true AND c.fqdns IS DISTINCT FROM i.fqdns -- order-sensitive diff RETURNING c.fqdn_id ), ins AS ( INSERT INTO recon.cname_record_sets (fqdn_id, fqdns, a_records_id, created_at) SELECT i.fqdn_id, i.fqdns, i.a_records_id, i.created_at FROM input i LEFT JOIN recon.cname_record_sets cur ON cur.fqdn_id = i.fqdn_id AND cur.is_current = true WHERE cur.fqdn_id IS NULL OR i.fqdn_id IN (SELECT fqdn_id FROM retire) RETURNING id ) SELECT 1; "#, ) .bind(JsonValue::Array(cname_rows_with_a)) .execute(&mut *tx) .await?; } } tx.commit().await?; Ok(()) } /// Return a vector of historical (timestamp, cnames, ip_addresses) resolution chain for given `fqdn`. /// /// TODO: The timestamp of CNAME and A records may differ. We should consider /// how to best represent this in the returned history. Perhaps timestamp per record type? pub async fn list_fqdn_history( &self, fqdn: &str, ) -> Result, Vec, HashSet)>> { // We fetch all historical cname_record_sets for the FQDN along with their associated // (possibly nil) a_record_sets. If a current or historical CNAME row references an // a_records_id, we join to obtain its ip_addresses set. We then return ordered by timestamp DESC. let rows = sqlx::query( r#" SELECT c.created_at, c.fqdns, a.ip_addresses FROM recon.cname_record_sets c LEFT JOIN recon.a_record_sets a ON a.id = c.a_records_id WHERE c.fqdn_id = $1 ORDER BY c.created_at DESC "#, ) .bind(fqdn) .fetch_all(self.pool()) .await?; let mut history = Vec::with_capacity(rows.len()); for r in rows.into_iter() { let created_at: DateTime = r.get("created_at"); let fqdns: Vec = r.get::, _>("fqdns"); let ips_opt: Option> = r.try_get("ip_addresses").ok(); let ip_set: HashSet = ips_opt.unwrap_or_default().into_iter().collect(); history.push((created_at, fqdns, ip_set)); } Ok(history) } /// Search for FQDNs matching the PostgreSQL fqdn_glob_to_lquery function. /// /// If `deep_edges` is true, the search will also include /// FQDNs that are more than one level deeper than the query. /// /// # Examples /// /// search_fqdns("*.example.com", false) -> ["example.com", "www.example.com"] /// search_fqdns("www.*.com", false) -> ["www.example.com", "www.another.com"] /// search_fqdns("*.example.com", true) -> ["www.example.com", "sub.www.example.com"] pub async fn search_fqdns(&self, query: &str, deep_edges: bool) -> Result> { let rows: Vec = sqlx::query_scalar( r#" SELECT id FROM recon.fqdns WHERE path ~ util.fqdn_glob_to_lquery($1, $2) "#, ) .bind(query) .bind(deep_edges) .fetch_all(self.pool()) .await?; Ok(rows) } /// Delete a single domain by its FQDN. 
pub async fn delete_domain(&self, fqdn: &str) -> Result { sqlx::query( r#" DELETE FROM recon.fqdns WHERE id = $1 "#, ) .bind(fqdn) .execute(self.pool()) .await?; Ok(()) } /// Delete domains by their FQDNs. pub async fn delete_domains(&self, fqdns: Vec) -> Result { if fqdns.is_empty() { return Err(Error::Other("No fqdns provided to delete".into())); } sqlx::query( r#" DELETE FROM recon.fqdns WHERE id = ANY($1) "#, ) .bind(&fqdns) .execute(self.pool()) .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/recon/ipv4address.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/ipv4address.rs` ```rust use crate::DbClient; use crate::recon::Ipv4Address; use crate::{Error, Result}; use serde_json::{Value as JsonValue, json}; use sqlx::postgres::PgQueryResult; use std::collections::HashSet; impl DbClient { /// List all unique IP addresses associated with a target. pub async fn list_ip_addresses( &self, target_id: Option, ) -> Result> { let ip_addresses = if target_id.is_none() { // No target_ids provided, return all IPs. sqlx::query_as( r#" SELECT ip::text, target_id, tcp_ports, udp_ports FROM recon.ip_addresses "#, ) .fetch_all(self.pool()) .await? } else { sqlx::query_as( r#" SELECT ip::text, target_id, tcp_ports, udp_ports FROM recon.ip_addresses WHERE target_id = $1 "#, ) .bind(&target_id) .fetch_all(self.pool()) .await? } .into_iter() .collect(); Ok(ip_addresses) } /// Get a single IP address by its string representation. pub async fn get_ip_address(&self, ip: &str) -> Result { let row: Ipv4Address = sqlx::query_as( r#" SELECT ip::text, target_id, tcp_ports, udp_ports FROM recon.ip_addresses WHERE ip = $1 "#, ) .bind(ip) .fetch_one(self.pool()) .await?; Ok(row) } /// Insert or update multiple IP addresses. pub async fn put_ip_addresses(&self, ips: HashSet) -> Result { if ips.is_empty() { return Err(Error::Other("No ips provided to update".into())); } // Convert ips to JSON array and use jsonb_to_recordset to upsert multiple rows in one query. 
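        // Illustrative shape of the bound JSONB array (hypothetical values):
        //   [{"ip": "192.0.2.10", "target_id": "example-target",
        //     "tcp_ports": [80, 443], "udp_ports": []}]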
        let mut json_arr: Vec<JsonValue> = Vec::with_capacity(ips.len());
        for ip in ips.into_iter() {
            let obj = json!({
                "ip": ip.ip,
                "target_id": ip.target_id,
                "tcp_ports": ip.tcp_ports,
                "udp_ports": ip.udp_ports,
            });
            json_arr.push(obj);
        }
        let jsonb = JsonValue::Array(json_arr);
        sqlx::query(
            r#"
            INSERT INTO recon.ip_addresses (ip, target_id, tcp_ports, udp_ports)
            SELECT x.ip, x.target_id, x.tcp_ports, x.udp_ports
            FROM jsonb_to_recordset($1::jsonb) AS x(ip text, target_id text, tcp_ports int[], udp_ports int[])
            ON CONFLICT (ip) DO UPDATE SET
                target_id = EXCLUDED.target_id,
                tcp_ports = EXCLUDED.tcp_ports,
                udp_ports = EXCLUDED.udp_ports
            "#,
        )
        .bind(&jsonb)
        .execute(self.pool())
        .await?;
        Ok(())
    }

    pub async fn update_ip_ports(
        &self,
        ip: &str,
        tcp_ports: Option<Vec<u16>>,
        udp_ports: Option<Vec<u16>>,
    ) -> Result<()> {
        if tcp_ports.is_none() && udp_ports.is_none() {
            return Err(Error::Other("No ports provided to update".into()));
        }
        let result: Result<PgQueryResult> = match (tcp_ports, udp_ports) {
            (Some(tcp), Some(udp)) => {
                let tcp_i: Vec<i32> = tcp.into_iter().map(|p| p as i32).collect();
                let udp_i: Vec<i32> = udp.into_iter().map(|p| p as i32).collect();
                sqlx::query(
                    r#"
                    UPDATE recon.ip_addresses SET
                        tcp_ports = (
                            SELECT ARRAY(SELECT DISTINCT p FROM unnest(coalesce(tcp_ports,'{}') || $2::int[]) p ORDER BY 1)
                        ),
                        udp_ports = (
                            SELECT ARRAY(SELECT DISTINCT p FROM unnest(coalesce(udp_ports,'{}') || $3::int[]) p ORDER BY 1)
                        )
                    WHERE ip = $1
                    "#,
                )
                .bind(ip)
                .bind(&tcp_i)
                .bind(&udp_i)
                .execute(self.pool())
                .await
                .map_err(crate::Error::from)
            }
            (Some(tcp), None) => {
                let tcp_i: Vec<i32> = tcp.into_iter().map(|p| p as i32).collect();
                sqlx::query(
                    r#"
                    UPDATE recon.ip_addresses SET
                        tcp_ports = (
                            SELECT ARRAY(SELECT DISTINCT p FROM unnest(coalesce(tcp_ports,'{}') || $2::int[]) p ORDER BY 1)
                        )
                    WHERE ip = $1
                    "#,
                )
                .bind(ip)
                .bind(&tcp_i)
                .execute(self.pool())
                .await
                .map_err(crate::Error::from)
            }
            (None, Some(udp)) => {
                let udp_i: Vec<i32> = udp.into_iter().map(|p| p as i32).collect();
                sqlx::query(
                    r#"
                    UPDATE recon.ip_addresses SET
                        udp_ports = (
                            SELECT ARRAY(SELECT DISTINCT p FROM unnest(coalesce(udp_ports,'{}') || $2::int[]) p ORDER BY 1)
                        )
                    WHERE ip = $1
                    "#,
                )
                .bind(ip)
                .bind(&udp_i)
                .execute(self.pool())
                .await
                .map_err(crate::Error::from)
            }
            (None, None) => unreachable!("This case is handled at the beginning of the function"),
        };
        if result?.rows_affected() == 0 {
            return Err(Error::Other(format!("No ip found to update: {}", ip)));
        }
        Ok(())
    }

    pub async fn add_tcp_ports(&self, ip: &str, ports: HashSet<u16>) -> Result<()> {
        if ports.is_empty() {
            return Err(Error::Other("No ports provided to update".into()));
        }
        let v: Vec<u16> = ports.into_iter().collect();
        self.update_ip_ports(ip, Some(v), None).await
    }

    pub async fn add_udp_ports(&self, ip: &str, ports: HashSet<u16>) -> Result<()> {
        if ports.is_empty() {
            return Err(Error::Other("No ports provided to update".into()));
        }
        let v: Vec<u16> = ports.into_iter().collect();
        self.update_ip_ports(ip, None, Some(v)).await
    }

    pub async fn delete_ip_address(&self, ip: &str) -> Result<()> {
        sqlx::query("DELETE FROM recon.ip_addresses WHERE ip = $1")
            .bind(ip)
            .execute(self.pool())
            .await?;
        Ok(())
    }

    pub async fn delete_ip_addresses(&self, ips: Vec<String>) -> Result<()> {
        if ips.is_empty() {
            return Err(Error::Other("No ips provided to delete".into()));
        }
        sqlx::query("DELETE FROM recon.ip_addresses WHERE ip = ANY($1)")
            .bind(&ips)
            .execute(self.pool())
            .await?;
        Ok(())
    }
}
```

### `shared/mrpf_core/src/db/recon/jobs.rs`

**Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/jobs.rs`

```rust
use crate::DbClient;
use crate::recon::{Job, JobState};
use crate::{Error, Result};
use
chrono::{DateTime, Utc}; use serde_json::{Value as JsonValue, json}; use uuid::Uuid; impl DbClient { /// Fetch a job (by id) including latest run history (if any) in a single query. /// /// # Arguments /// * `id` - The ID of the TaskCollection to fetch. pub async fn get_job(&self, id: &Uuid) -> Result { // Use LEFT JOIN LATERAL to fetch latest history row (ordered by run_started) efficiently. // Column aliases (latest_* / latest_history_*) are mapped inside TaskCollection::from_row. let job: Job = sqlx::query_as( r#" SELECT tc.*, th.id AS latest_history_id, th.job_id AS latest_history_job_id, th.state AS latest_history_state, th.stats AS latest_history_stats, th.started_at AS latest_history_started_at, th.completed_at AS latest_history_completed_at FROM recon.jobs tc LEFT JOIN LATERAL ( SELECT id, job_id, state, stats, started_at, completed_at FROM recon.job_history WHERE job_id = tc.id ORDER BY started_at DESC LIMIT 1 ) th ON TRUE WHERE tc.id = $1 "#, ) .bind(id) .fetch_one(self.pool()) .await?; Ok(job) } /// List all jobs from the database. pub async fn list_jobs(&self, job_ids: Vec) -> Result> { if !job_ids.is_empty() { let jobs: Vec = sqlx::query_as( r#" SELECT tc.*, th.id AS latest_history_id, th.job_id AS latest_history_job_id, th.state AS latest_history_state, th.stats AS latest_history_stats, th.started_at AS latest_history_started_at, th.completed_at AS latest_history_completed_at FROM recon.jobs tc LEFT JOIN LATERAL ( SELECT id, job_id, state, stats, started_at, completed_at FROM recon.job_history WHERE job_id = tc.id ORDER BY started_at DESC LIMIT 1 ) th ON TRUE WHERE tc.id = ANY($1) "#, ) .bind(&job_ids) .fetch_all(self.pool()) .await?; Ok(jobs) } else { let jobs: Vec = sqlx::query_as( r#" SELECT tc.*, th.id AS latest_history_id, th.job_id AS latest_history_job_id, th.state AS latest_history_state, th.stats AS latest_history_stats, th.started_at AS latest_history_started_at, th.completed_at AS latest_history_completed_at FROM recon.jobs tc LEFT JOIN LATERAL ( SELECT id, job_id, state, stats, started_at, completed_at FROM recon.job_history WHERE job_id = tc.id ORDER BY started_at DESC LIMIT 1 ) th ON TRUE "#, ) .fetch_all(self.pool()) .await?; Ok(jobs) } } /// Upsert a single job into the database. /// /// NOTE: This will NOT update the latest_run/history of the TaskCollection. /// /// # Arguments /// * `job` - The TaskCollection to upsert. pub async fn put_job(&self, job: Job) -> Result { // Upsert single TaskCollection and return the fully populated (including latest history) row. // Using RETURNING keeps it a single round-trip for the ID; we then leverage get_job // for consistent row + lateral latest history. 
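        // Upsert semantics, illustrated with hypothetical values:
        //   existing name = "daily-scan", EXCLUDED.name IS NULL       -> name stays "daily-scan"
        //   existing recurrence = 5,      EXCLUDED.recurrence IS NULL -> recurrence becomes NULL
        //   (target_id and recurrence intentionally allow NULL overwrites; see the SQL comments below)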
let query = r#" INSERT INTO recon.jobs ( id, name, target_id, definition, schedule, recurrence, created_at ) VALUES ( $1::uuid, $2::text, $3::text, $4::jsonb, $5::text, $6::int, $7::timestamptz ) ON CONFLICT (id) DO UPDATE SET name = COALESCE(EXCLUDED.name, recon.jobs.name), target_id = EXCLUDED.target_id, -- Allow NULL overwrite definition = COALESCE(EXCLUDED.definition, recon.jobs.definition), schedule = COALESCE(EXCLUDED.schedule, recon.jobs.schedule), recurrence = EXCLUDED.recurrence -- allow NULL overwrite RETURNING id "#; let inserted_id: Uuid = sqlx::query_scalar(query) .bind(job.id) .bind(job.name) .bind(job.target_id) .bind(job.definition) .bind(job.schedule) .bind(job.recurrence) .bind(job.created_at) .fetch_one(self.pool()) .await?; let updated_job = self.get_job(&inserted_id).await?; Ok(updated_job) } /// Select all jobs that are eligible to start (not running, recurrence not 0) /// and mark them as running by inserting a history row and decrementing recurrence. Returns the /// started jobs with updated latest history rows. #[cfg(feature = "task_manager")] pub async fn start_due_jobs(&self) -> Result> { let mut tx = self.pool().begin().await?; let now = Utc::now(); // Lock candidate Job rows (not currently running, recurrence not 0) and include latest history via lateral join. // We include the latest history columns so Job::from_row can populate `latest_run`, which is required // for correct schedule evaluation (is_due relies on latest_run.started_at as last_run_time). let candidates: Vec = sqlx::query_as( r#" SELECT tc.*, th.id AS latest_history_id, th.job_id AS latest_history_job_id, th.state AS latest_history_state, th.stats AS latest_history_stats, th.started_at AS latest_history_started_at, th.completed_at AS latest_history_completed_at FROM recon.jobs tc LEFT JOIN LATERAL ( SELECT id, job_id, state, stats, started_at, completed_at FROM recon.job_history WHERE job_id = tc.id ORDER BY started_at DESC LIMIT 1 ) th ON TRUE WHERE (tc.recurrence IS NULL OR tc.recurrence <> 0) AND (th.state IS DISTINCT FROM $1) FOR UPDATE OF tc SKIP LOCKED "#, ) .bind(JobState::Running) .fetch_all(&mut *tx) .await?; if candidates.is_empty() { tx.commit().await?; return Ok(Vec::new()); } // Evaluate schedule for each candidate to determine if due. let mut due: Vec = Vec::new(); for tc in candidates.into_iter() { let (is_due, _next_due) = tc.is_due(now); if is_due { due.push(tc); } } if due.is_empty() { // Nothing actually due after schedule evaluation. tx.commit().await?; return Ok(Vec::new()); } // Build JSON for histories only for due collections. let mut json_arr: Vec = Vec::with_capacity(due.len()); for tc in due.iter() { json_arr.push(json!({ "id": Uuid::now_v7(), "job_id": tc.id, "state": JobState::Running, "stats": {}, "started_at": now, "completed_at": None::>, })); } let jsonb = JsonValue::Array(json_arr); // Insert histories & decrement recurrence. 
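        // Illustrative shape of one history row in the bound JSONB (hypothetical values):
        //   {"id": "<uuid v7>", "job_id": "<job uuid>", "state": "Running", "stats": {},
        //    "started_at": "2025-01-01T00:00:00Z", "completed_at": null}
        // Recurrence is only decremented when it is non-NULL and > 0; a NULL recurrence is
        // treated as unlimited and left untouched.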
sqlx::query( r#" WITH new_histories AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id uuid, job_id uuid, state job_state, stats jsonb, started_at timestamptz, completed_at timestamptz ) ), inserted AS ( INSERT INTO recon.job_history ( id, job_id, state, stats, started_at, completed_at ) SELECT id, job_id, state, COALESCE(stats, '{}'::jsonb), COALESCE(started_at, now()), completed_at FROM new_histories ) UPDATE recon.jobs tc SET recurrence = CASE WHEN tc.recurrence IS NOT NULL THEN tc.recurrence - 1 ELSE tc.recurrence END, updated_at = now() WHERE tc.id IN (SELECT job_id FROM new_histories) AND tc.recurrence IS NOT NULL AND tc.recurrence > 0; "#, ) .bind(jsonb) .execute(&mut *tx) .await?; // If we do want the latest history rows included in the returned TaskCollections, we would need to requery them here. // For now, we return the due collections without any history populated. // // Requery the started collections to include the fresh latest history row. // let started_ids: Vec = due.iter().map(|tc| tc.id).collect(); // let started: Vec = sqlx::query_as( // r#" // SELECT tc.*, // th.id AS latest_history_id, // th.job_id AS latest_history_job_id, // th.state AS latest_history_state, // th.stats AS latest_history_stats, // th.started_at AS latest_history_started_at, // th.completed_at AS latest_history_completed_at // FROM recon.jobs tc // LEFT JOIN LATERAL ( // SELECT id, job_id, state, stats, started_at, completed_at // FROM recon.job_history // WHERE job_id = tc.id // ORDER BY started_at DESC // LIMIT 1 // ) th ON TRUE // WHERE tc.id = ANY($1) // "#, // ) // .bind(&started_ids) // .fetch_all(&mut *tx) // .await?; tx.commit().await?; Ok(due) } /// Upsert multiple jobs into the database. /// /// NOTE: This will NOT update the latest_run/history of the TaskCollections. /// /// # Arguments /// * `jobs` - A vector of TaskCollections to upsert. /// pub async fn put_jobs(&self, jobs: Vec) -> Result> { if jobs.is_empty() { return Err(Error::Other("No jobs provided to update".into())); } // Convert jobs into a JSONB array and use jsonb_to_recordset for bulk upsert. let mut json_arr: Vec = Vec::with_capacity(jobs.len()); for job in jobs.into_iter() { json_arr.push(json!({ "id": job.id, "name": job.name, "target_id": job.target_id, "definition": job.definition, "schedule": job.schedule, "recurrence": job.recurrence, "created_at": job.created_at, })); } let jsonb = JsonValue::Array(json_arr); let query = r#" WITH new_jobs AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id uuid, name text, target_id text, definition jsonb, schedule text, recurrence int, created_at timestamptz ) ) INSERT INTO recon.jobs ( id, name, target_id, definition, schedule, recurrence, created_at ) SELECT id, COALESCE(name, ''), target_id, COALESCE(definition, '{}'::jsonb), COALESCE(schedule, ''), recurrence, COALESCE(created_at, now()) FROM new_jobs ON CONFLICT (id) DO UPDATE SET name = COALESCE(EXCLUDED.name, recon.jobs.name), target_id = EXCLUDED.target_id, -- Allow NULL overwrite definition = COALESCE(EXCLUDED.definition, recon.jobs.definition), schedule = COALESCE(EXCLUDED.schedule, recon.jobs.schedule), recurrence = EXCLUDED.recurrence --- allow NULL overwrite RETURNING id "#; let ids: Vec = sqlx::query_scalar(query) .bind(jsonb) .fetch_all(self.pool()) .await?; let updated_jobs = self.list_jobs(ids).await?; Ok(updated_jobs) } /// Delete a job by id. /// /// # Arguments /// * `job_id` - The ID of the TaskCollection to delete. 
pub async fn delete_job(&self, job_id: &Uuid) -> Result { sqlx::query("DELETE FROM jobs WHERE id = $1") .bind(job_id) .execute(self.pool()) .await?; Ok(()) } /// Delete multiple jobs by their IDs. /// /// # Arguments /// * `job_ids` - A vector of TaskCollection IDs to delete pub async fn delete_jobs(&self, job_ids: Vec) -> Result { if job_ids.is_empty() { return Err(Error::Other("No jobs provided to delete".into())); } sqlx::query("DELETE FROM jobs WHERE id = ANY($1)") .bind(&job_ids) .execute(self.pool()) .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/recon/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/mod.rs` ```rust mod apn_device; mod domain; mod ipv4address; mod jobs; mod notification; mod target; mod words; ``` ### `shared/mrpf_core/src/db/recon/notification.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/notification.rs` ```rust use crate::DbClient; use crate::recon::{Notification, NotificationCategory}; use crate::{Error, Result}; use uuid::Uuid; impl DbClient { /// Fetch a single notification by id. pub async fn get_notification(&self, id: &Uuid) -> Result { let notification: Notification = sqlx::query_as( r#"SELECT id, title, subtitle, content, is_read, category, target_id, created_at FROM recon.notifications WHERE id = $1"#, ) .bind(id) .fetch_one(self.pool()) .await?; Ok(notification) } /// List notifications with optional filters. /// /// # Arguments /// /// * `notification_ids` - Vector of notification UUIDs to filter by. If empty, no filtering on IDs is applied. /// * `category` - Optional notification category to filter by. /// * `is_read` - Optional boolean to filter notifications by their read status. /// * `target_id` - Optional string to filter notifications by their target ID. /// * `page` - Optional page number for pagination (zero-based). If provided without `records_per_page`, defaults to 100 records per page. /// * `records_per_page` - Optional number of records per page for pagination. If provided without `page`, returns the first `n` records. pub async fn list_notifications( &self, notification_ids: Vec, category: Option, is_read: Option, target_id: Option, page: Option, records_per_page: Option, ) -> Result> { // Calculate pagination parameters let (limit, offset) = match (page, records_per_page) { (Some(p), Some(rpp)) => { // Ensure positive records_per_page let rpp = if rpp <= 0 { 100 } else { rpp }; // Ensure positive page let safe_p = p.max(0); // Calculate offset for pagination let offset = safe_p * rpp; (Some(rpp), Some(offset)) } (Some(p), None) => { // default when page provided without records_per_page let rpp: i64 = 100; // Ensure positive page let safe_p = p.max(0); // Calculate offset for pagination let offset = safe_p * rpp; (Some(rpp), Some(offset)) } (None, Some(rpp)) => { // Ensure positive records_per_page let rpp = if rpp <= 0 { 100 } else { rpp }; // Default offset to 0 when page is not provided (Some(rpp), Some(0)) } (None, None) => (None, None), }; // Build and execute query depending on whether we have a LIMIT. 
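        // Pagination example (hypothetical values): page = 2, records_per_page = 50
        //   -> LIMIT 50 OFFSET 100 (page index is zero-based).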
        let rows: Vec<Notification> = if let Some(limit) = limit {
            let offset = offset.unwrap_or(0);
            sqlx::query_as(
                r#"
                SELECT id, title, subtitle, content, is_read, category, target_id, created_at
                FROM recon.notifications
                WHERE (cardinality($1::uuid[]) = 0 OR id = ANY($1))
                  AND ($2::notification_category IS NULL OR category = $2)
                  AND ($3::bool IS NULL OR is_read = $3)
                  AND ($4::text IS NULL OR target_id = $4)
                ORDER BY created_at DESC
                LIMIT $5 OFFSET $6
                "#,
            )
            .bind(&notification_ids)
            .bind(category)
            .bind(is_read)
            .bind(target_id)
            .bind(limit)
            .bind(offset)
            .fetch_all(self.pool())
            .await?
        } else {
            // No limit requested
            sqlx::query_as(
                r#"
                SELECT id, title, subtitle, content, is_read, category, target_id, created_at
                FROM recon.notifications
                WHERE (cardinality($1::uuid[]) = 0 OR id = ANY($1))
                  AND ($2::notification_category IS NULL OR category = $2)
                  AND ($3::bool IS NULL OR is_read = $3)
                  AND ($4::text IS NULL OR target_id = $4)
                ORDER BY created_at DESC
                "#,
            )
            .bind(&notification_ids)
            .bind(category)
            .bind(is_read)
            .bind(target_id)
            .fetch_all(self.pool())
            .await?
        };
        Ok(rows)
    }

    /// Upsert a single notification.
    pub async fn put_notification(&self, notification: &Notification) -> Result<()> {
        sqlx::query(
            r#"
            INSERT INTO recon.notifications (id, title, subtitle, content, is_read, category, target_id, created_at)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            ON CONFLICT (id) DO UPDATE SET
                title = COALESCE(EXCLUDED.title, recon.notifications.title),
                subtitle = EXCLUDED.subtitle,
                content = COALESCE(EXCLUDED.content, recon.notifications.content),
                is_read = COALESCE(EXCLUDED.is_read, recon.notifications.is_read),
                category = COALESCE(EXCLUDED.category, recon.notifications.category),
                target_id = EXCLUDED.target_id
            "#,
        )
        .bind(notification.id)
        .bind(&notification.title)
        .bind(&notification.subtitle)
        .bind(&notification.content)
        .bind(notification.is_read)
        .bind(notification.category)
        .bind(&notification.target_id)
        .bind(notification.created_at)
        .execute(self.pool())
        .await?;
        Ok(())
    }

    /// Bulk upsert notifications.
    pub async fn put_notifications(&self, notifications: &[Notification]) -> Result<()> {
        if notifications.is_empty() {
            return Err(Error::Other("No notifications provided to update".into()));
        }
        use serde_json::{Value as JsonValue, json};
        let rows: Vec<JsonValue> = notifications
            .iter()
            .map(|n| {
                json!({
                    "id": n.id,
                    "title": n.title,
                    "subtitle": n.subtitle,
                    "content": n.content,
                    "is_read": n.is_read,
                    "category": n.category,
                    "target_id": n.target_id,
                    "created_at": n.created_at,
                })
            })
            .collect();
        sqlx::query(
            r#"
            WITH input AS (
                SELECT * FROM jsonb_to_recordset($1::jsonb) AS t(
                    id uuid,
                    title text,
                    subtitle text,
                    content text,
                    is_read bool,
                    category notification_category,
                    target_id text,
                    created_at timestamptz
                )
            )
            INSERT INTO recon.notifications (id, title, subtitle, content, is_read, category, target_id, created_at)
            SELECT id, COALESCE(title, ''), subtitle, COALESCE(content, ''), COALESCE(is_read, false), category, target_id, COALESCE(created_at, now())
            FROM input
            ON CONFLICT (id) DO UPDATE SET
                title = COALESCE(EXCLUDED.title, recon.notifications.title),
                subtitle = EXCLUDED.subtitle,
                content = COALESCE(EXCLUDED.content, recon.notifications.content),
                is_read = COALESCE(EXCLUDED.is_read, recon.notifications.is_read),
                category = COALESCE(EXCLUDED.category, recon.notifications.category),
                target_id = EXCLUDED.target_id
            "#,
        )
        .bind(JsonValue::Array(rows))
        .execute(self.pool())
        .await?;
        Ok(())
    }

    /// Mark a notification as read.
pub async fn mark_notification_read(&self, id: &Uuid) -> Result { sqlx::query("UPDATE recon.notifications SET is_read = true WHERE id = $1") .bind(id) .execute(self.pool()) .await?; Ok(()) } /// Mark a notification as unread. pub async fn mark_notification_unread(&self, id: &Uuid) -> Result { sqlx::query("UPDATE recon.notifications SET is_read = false WHERE id = $1") .bind(id) .execute(self.pool()) .await?; Ok(()) } /// Delete a single notification. pub async fn delete_notification(&self, id: &Uuid) -> Result { sqlx::query("DELETE FROM recon.notifications WHERE id = $1") .bind(id) .execute(self.pool()) .await?; Ok(()) } /// Delete multiple notifications by id. pub async fn delete_notifications(&self, ids: Vec) -> Result { if ids.is_empty() { return Err(Error::Other( "No notification ids provided to delete".into(), )); } sqlx::query("DELETE FROM recon.notifications WHERE id = ANY($1)") .bind(&ids) .execute(self.pool()) .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/recon/target.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/target.rs` ```rust use crate::DbClient; use crate::recon::Target; use crate::{Error, Result}; use serde_json::{Value as JsonValue, json}; /// Convert a HashSet to a sorted Vec /// /// We perform this before storing in the database to help reduce diffs when items are added. /// /// TODO: Move this somewhere common fn to_sorted_vec(set: &std::collections::HashSet) -> Vec { let mut v: Vec = set.iter().cloned().collect(); v.sort(); v } impl DbClient { /// Fetch a target by its name from the database. pub async fn get_target(&self, name: &str) -> Result { let target: Target = sqlx::query_as("SELECT * FROM recon.targets WHERE id = $1") .bind(name) .fetch_one(self.pool()) .await?; Ok(target) } /// List all targets from the database. pub async fn list_targets(&self, target_ids: Vec) -> Result> { if !target_ids.is_empty() { let targets: Vec = sqlx::query_as("SELECT * FROM recon.targets WHERE id = ANY($1)") .bind(&target_ids) .fetch_all(self.pool()) .await?; Ok(targets) } else { let targets: Vec = sqlx::query_as("SELECT * FROM recon.targets") .fetch_all(self.pool()) .await?; Ok(targets) } } pub async fn put_target(&self, target: &Target) -> Result { // Prepare arrays from HashSets let tlds: Vec = to_sorted_vec(&target.tlds); let internal_tlds: Vec = to_sorted_vec(&target.internal_tlds); let fqdns_in_scope: Vec = to_sorted_vec(&target.fqdns_in_scope); let fqdns_out_of_scope: Vec = to_sorted_vec(&target.fqdns_out_of_scope); let urls_in_scope: Vec = to_sorted_vec(&target.urls_in_scope); let urls_out_of_scope: Vec = to_sorted_vec(&target.urls_out_of_scope); let ip_ranges: Vec = target.ip_ranges.iter().map(|p| p.to_string()).collect(); let cert_org_names: Vec = to_sorted_vec(&target.cert_org_names); // Upsert: insert new row, leverage util.array_union_* functions to deduplicate, remove NULLs, // and sort items. Sorting helps reducing actual diffs in the DB when items are added. 
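        // Illustrative union semantics (hypothetical values):
        //   stored tlds = {a.example, b.example}, incoming tlds = {b.example, c.example}
        //   -> resulting tlds = {a.example, b.example, c.example} (deduplicated, NULL-free, sorted)
        // Scalar flags (has_bounties, is_private, favorite) only change when a non-NULL value is supplied.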
let query = r#" INSERT INTO recon.targets ( id, tlds, internal_tlds, fqdns_in_scope, fqdns_out_of_scope, urls_in_scope, urls_out_of_scope, ip_ranges, cert_org_names, has_bounties, is_private, favorite ) VALUES ( $1::text, $2::text[], $3::text[], $4::text[], $5::text[], $6::text[], $7::text[], $8::cidr[], $9::text[], COALESCE($10, true), COALESCE($11, true), COALESCE($12, false) ) ON CONFLICT (id) DO UPDATE SET tlds = util.array_union_text (recon.targets.tlds, EXCLUDED.tlds), internal_tlds = util.array_union_text (recon.targets.internal_tlds, EXCLUDED.internal_tlds), fqdns_in_scope = util.array_union_text(recon.targets.fqdns_in_scope, EXCLUDED.fqdns_in_scope), fqdns_out_of_scope = util.array_union_text(recon.targets.fqdns_out_of_scope, EXCLUDED.fqdns_out_of_scope), urls_in_scope = util.array_union_text (recon.targets.urls_in_scope, EXCLUDED.urls_in_scope), urls_out_of_scope = util.array_union_text (recon.targets.urls_out_of_scope, EXCLUDED.urls_out_of_scope), ip_ranges = util.array_union_cidr (recon.targets.ip_ranges, EXCLUDED.ip_ranges), cert_org_names = util.array_union_text(recon.targets.cert_org_names, EXCLUDED.cert_org_names), has_bounties = COALESCE(EXCLUDED.has_bounties, recon.targets.has_bounties), is_private = COALESCE(EXCLUDED.is_private, recon.targets.is_private), favorite = COALESCE(EXCLUDED.favorite, recon.targets.favorite) "#; sqlx::query(query) .bind(&target.id) .bind(&tlds) .bind(&internal_tlds) .bind(&fqdns_in_scope) .bind(&fqdns_out_of_scope) .bind(&urls_in_scope) .bind(&urls_out_of_scope) .bind(&ip_ranges) .bind(&cert_org_names) .bind(target.has_bounties) .bind(target.is_private) .bind(target.favorite) .execute(self.pool()) .await?; Ok(()) } pub async fn put_targets(&self, targets: Vec) -> Result { if targets.is_empty() { return Err(Error::Other("No targets provided to update".into())); } // Convert targets into a JSONB array and use jsonb_to_recordset for bulk upsert. 
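        // Illustrative shape of one element of the bound JSONB array (hypothetical values):
        //   {"id": "example-target", "tlds": ["example.com"], "internal_tlds": [],
        //    "cert_org_names": ["Example Corp"], "ip_ranges": ["192.0.2.0/24"],
        //    "fqdns_in_scope": ["www.example.com"], "fqdns_out_of_scope": [],
        //    "urls_in_scope": [], "urls_out_of_scope": [],
        //    "has_bounties": true, "is_private": false, "favorite": false}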
let mut json_arr: Vec = Vec::with_capacity(targets.len()); for target in targets.into_iter() { let obj = json!({ "id": target.id, "tlds": to_sorted_vec(&target.tlds), "internal_tlds": to_sorted_vec(&target.internal_tlds), "cert_org_names": to_sorted_vec(&target.cert_org_names), "ip_ranges": target.ip_ranges.iter().map(|p| p.to_string()).collect::>(), "fqdns_in_scope": to_sorted_vec(&target.fqdns_in_scope), "fqdns_out_of_scope": to_sorted_vec(&target.fqdns_out_of_scope), "urls_in_scope": to_sorted_vec(&target.urls_in_scope), "urls_out_of_scope": to_sorted_vec(&target.urls_out_of_scope), "has_bounties": target.has_bounties, "is_private": target.is_private, "favorite": target.favorite }); json_arr.push(obj); } let jsonb = JsonValue::Array(json_arr); let query = r#" WITH new_targets AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id text, tlds text[], internal_tlds text[], cert_org_names text[], ip_ranges cidr[], fqdns_in_scope text[], fqdns_out_of_scope text[], urls_in_scope text[], urls_out_of_scope text[], has_bounties boolean, is_private boolean, favorite boolean ) ) INSERT INTO recon.targets ( id, tlds, internal_tlds, cert_org_names, ip_ranges, fqdns_in_scope, fqdns_out_of_scope, urls_in_scope, urls_out_of_scope, has_bounties, is_private, favorite ) SELECT id, coalesce(tlds, ARRAY[]::text[]), coalesce(internal_tlds, ARRAY[]::text[]), coalesce(cert_org_names, ARRAY[]::text[]), coalesce(ip_ranges, ARRAY[]::cidr[]), coalesce(fqdns_in_scope, ARRAY[]::text[]), coalesce(fqdns_out_of_scope, ARRAY[]::text[]), coalesce(urls_in_scope, ARRAY[]::text[]), coalesce(urls_out_of_scope, ARRAY[]::text[]), COALESCE(has_bounties, false), COALESCE(is_private, false), COALESCE(favorite, false) FROM new_targets ON CONFLICT (id) DO UPDATE SET tlds = util.array_union_text (recon.targets.tlds, EXCLUDED.tlds), internal_tlds = util.array_union_text (recon.targets.internal_tlds, EXCLUDED.internal_tlds), cert_org_names = util.array_union_text (recon.targets.cert_org_names, EXCLUDED.cert_org_names), ip_ranges = util.array_union_cidr (recon.targets.ip_ranges, EXCLUDED.ip_ranges), fqdns_in_scope = util.array_union_text (recon.targets.fqdns_in_scope, EXCLUDED.fqdns_in_scope), fqdns_out_of_scope = util.array_union_text (recon.targets.fqdns_out_of_scope,EXCLUDED.fqdns_out_of_scope), urls_in_scope = util.array_union_text (recon.targets.urls_in_scope, EXCLUDED.urls_in_scope), urls_out_of_scope = util.array_union_text (recon.targets.urls_out_of_scope, EXCLUDED.urls_out_of_scope), has_bounties = COALESCE(EXCLUDED.has_bounties, recon.targets.has_bounties), is_private = COALESCE(EXCLUDED.is_private, recon.targets.is_private), favorite = COALESCE(EXCLUDED.favorite, recon.targets.favorite) "#; sqlx::query(query).bind(jsonb).execute(self.pool()).await?; Ok(()) } pub async fn mark_target_favorite(&self, target_id: &str) -> Result { sqlx::query("UPDATE recon.targets SET favorite = true WHERE id = $1") .bind(target_id) .execute(self.pool()) .await?; Ok(()) } pub async fn mark_target_unfavorite(&self, target_id: &str) -> Result { sqlx::query("UPDATE recon.targets SET favorite = false WHERE id = $1") .bind(target_id) .execute(self.pool()) .await?; Ok(()) } pub async fn delete_target(&self, target_id: &str) -> Result { sqlx::query("DELETE FROM recon.targets WHERE id = $1") .bind(target_id) .execute(self.pool()) .await?; Ok(()) } pub async fn delete_targets(&self, target_ids: Vec) -> Result { if target_ids.is_empty() { return Err(Error::Other("No targets provided to delete".into())); } sqlx::query("DELETE FROM recon.targets WHERE id 
= ANY($1)") .bind(&target_ids) .execute(self.pool()) .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/recon/words.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/recon/words.rs` ```rust use crate::DbClient; use crate::recon::{TemplateKey, Word, WordCategory}; use crate::{Error, Result}; use std::collections::HashSet; use uuid::Uuid; type TargetId = String; impl DbClient { /// Generate a wordlist for the given parameters /// /// Note: We don't want to return a HashSet as the order might be important in the future. /// /// # Arguments /// * `included_category_names` - List of category names to include (word must have at least one) /// * `excluded_category_names` - List of category names to exclude (word must have none) /// * `target_id` - Optional target_id to filter words by. If provided, only words with matching target_id are /// included. If None, only global words (target_id NULL) are included. pub async fn generate_wordlist( &self, included_category_names: Vec, excluded_category_names: Vec, target_id: Option, ) -> Result> { let words = sqlx::query_as( r#" WITH wc AS ( SELECT w.id, w.name, w.target_id, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.id), NULL), ARRAY[]::uuid[]) AS category_ids, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.name), NULL), ARRAY[]::text[]) AS category_names, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.description), NULL), ARRAY[]::text[]) AS category_descriptions, COALESCE( ( SELECT ARRAY_AGG(DISTINCT tk.name) FROM recon.word_template_key_relation wtk JOIN recon.template_keys tk ON tk.id = wtk.template_key_id WHERE wtk.word_id = w.id ), ARRAY[]::text[] ) AS template_key_names FROM recon.words w LEFT JOIN recon.word_category_relation wcr ON wcr.word_id = w.id LEFT JOIN recon.word_categories c ON c.id = wcr.category_id GROUP BY w.id ) SELECT * FROM wc -- target filtering semantics: -- * if $3 (target_id param) IS NULL -> return ONLY global words (wc.target_id IS NULL) -- * if $3 is NOT NULL -> return ONLY words for that target (wc.target_id = $3) -- Using IS NOT DISTINCT FROM captures exactly that (NULL matches NULL; non-NULL behaves like =) WHERE wc.target_id IS NOT DISTINCT FROM $3 AND ( $1::text[] IS NULL OR $1 = '{}'::text[] OR EXISTS (SELECT 1 FROM unnest($1::text[]) inc WHERE inc = ANY(wc.category_names)) ) AND NOT EXISTS ( SELECT 1 FROM unnest($2::text[]) exc WHERE exc = ANY(wc.category_names) ) ORDER BY wc.name "#, ) .bind(&included_category_names) .bind(&excluded_category_names) .bind(&target_id) .fetch_all(self.pool()) .await?.into_iter().collect(); Ok(words) } /// Retrieve a single word by it's Id pub async fn get_word(&self, id: Uuid) -> Result { let rec: Word = sqlx::query_as( r#" WITH wc AS ( SELECT w.id, w.name, w.target_id, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.id), NULL), ARRAY[]::uuid[]) AS category_ids, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.name), NULL), ARRAY[]::text[]) AS category_names, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.description), NULL), ARRAY[]::text[]) AS category_descriptions, COALESCE( ( SELECT ARRAY_AGG(DISTINCT tk.name) FROM recon.word_template_key_relation wtk JOIN recon.template_keys tk ON tk.id = wtk.template_key_id WHERE wtk.word_id = w.id ), ARRAY[]::text[] ) AS template_key_names FROM recon.words w LEFT JOIN recon.word_category_relation wcr ON wcr.word_id = w.id LEFT JOIN recon.word_categories c ON c.id = wcr.category_id WHERE w.id = $1 GROUP BY w.id ) SELECT * FROM wc "#, ) .bind(id) .fetch_one(self.pool()) .await?; Ok(rec) } /// Upsert multiple words and their relationships to categories and 
template keys. pub async fn put_words(&self, words: Vec) -> Result { if words.is_empty() { return Err(Error::Other("No words provided to update".into())); } // Single-roundtrip strategy: // Provide three JSON arrays (words, relations word->category, relations word->template_key_name) // Use CTEs to: // 1. Upsert words (by id) // 2. Insert category relations (word_id/category_id) // 3. Insert template key relations resolving template_key_id by name // Assumes categories & template keys already exist (call put_categories / put_template_keys first if needed). use serde_json::Value as JsonValue; use serde_json::json; let mut word_rows: Vec = Vec::with_capacity(words.len()); let mut category_relations: Vec = Vec::new(); let mut template_key_relations: Vec = Vec::new(); for w in &words { word_rows.push(json!({ "id": w.id().to_string(), "name": w.name(), "target_id": w.target_id(), })); for cat in w.categories() { category_relations.push(json!({ "word_id": w.id().to_string(), "category_id": cat.id().to_string(), })); } for tk in w.applicable_template_keys() { template_key_relations.push(json!({ "word_id": w.id().to_string(), "template_key_name": tk, })); } } sqlx::query( r#" WITH word_input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id uuid, name text, target_id text ) ), upsert_words AS ( INSERT INTO recon.words (id, name, target_id, updated_at) SELECT (id)::uuid, name, target_id, now() FROM word_input ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, target_id = EXCLUDED.target_id, updated_at = now() RETURNING id ), relation_input AS ( SELECT * FROM jsonb_to_recordset($2::jsonb) AS t( word_id uuid, category_id uuid ) ), insert_category_relations AS ( INSERT INTO recon.word_category_relation (word_id, category_id) SELECT word_id, category_id FROM relation_input ON CONFLICT DO NOTHING RETURNING 1 ), word_template_input AS ( SELECT * FROM jsonb_to_recordset($3::jsonb) AS t( word_id uuid, template_key_name text ) ), insert_template_relations AS ( INSERT INTO recon.word_template_key_relation (word_id, template_key_id) SELECT wti.word_id, tk.id FROM word_template_input wti JOIN recon.template_keys tk ON tk.name = wti.template_key_name ON CONFLICT DO NOTHING RETURNING 1 ) SELECT 1 -- execute needs a final SELECT? "#, ) .bind(JsonValue::Array(word_rows)) .bind(JsonValue::Array(category_relations)) .bind(JsonValue::Array(template_key_relations)) .execute(self.pool()) .await?; // TODO: If keywords or categories were not present, does this silently fail? How // can we efficiently check this, preferably without additional DB calls. 
Ok(()) } /// Retrieve a single word by it's name and optional target_id pub async fn get_word_by_name(&self, name: &str, target_id: Option) -> Result { let rec: Word = sqlx::query_as( r#" WITH wc AS ( SELECT w.id, w.name, w.target_id, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.id), NULL), ARRAY[]::uuid[]) AS category_ids, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.name), NULL), ARRAY[]::text[]) AS category_names, COALESCE(ARRAY_REMOVE(ARRAY_AGG(DISTINCT c.description), NULL), ARRAY[]::text[]) AS category_descriptions, COALESCE( ( SELECT ARRAY_AGG(DISTINCT tk.name) FROM recon.word_template_key_relation wtk JOIN recon.template_keys tk ON tk.id = wtk.template_key_id WHERE wtk.word_id = w.id ), ARRAY[]::text[] ) AS template_key_names FROM recon.words w LEFT JOIN recon.word_category_relation wcr ON wcr.word_id = w.id LEFT JOIN recon.word_categories c ON c.id = wcr.category_id WHERE w.name = $1 AND (w.target_id IS NOT DISTINCT FROM $2) GROUP BY w.id ) SELECT * FROM wc "#, ) .bind(name) .bind(&target_id) .fetch_one(self.pool()) .await?; Ok(rec) } /// List all words pub async fn list_categories( &self, category_ids: Vec, category_names: Vec, ) -> Result> { let rows: Vec = sqlx::query_as( r#" SELECT id, name, description FROM recon.word_categories WHERE (cardinality($1::uuid[]) = 0 OR id = ANY($1)) AND (cardinality($2::text[]) = 0 OR name = ANY($2)) ORDER BY name "#, ) .bind(&category_ids) .bind(&category_names) .fetch_all(self.pool()) .await?; Ok(rows) } /// List all template keys pub async fn list_template_keys( &self, template_key_ids: Vec, template_key_names: Vec, ) -> Result> { let rows: Vec = sqlx::query_as( r#"SELECT id, name FROM recon.template_keys WHERE (cardinality($1::uuid[]) = 0 OR id = ANY($1)) AND (cardinality($2::text[]) = 0 OR name = ANY($2)) ORDER BY name"#, ) .bind(&template_key_ids) .bind(&template_key_names) .fetch_all(self.pool()) .await?; Ok(rows) } /// Delete all given word ids pub async fn delete_words(&self, ids: Vec) -> Result { if ids.is_empty() { return Err(Error::Other("No words provided to delete".into())); } sqlx::query("DELETE FROM recon.words WHERE id = ANY($1)") .bind(&ids) .execute(self.pool()) .await?; Ok(()) } /// Delete all words with the given names pub async fn delete_words_by_name_and_target( &self, names: Vec, target_id: Option, ) -> Result { if names.is_empty() { return Err(Error::Other("No words provided to delete".into())); } if let Some(tid) = target_id { // If target_id is provided, delete words matching both name and target_id (including NULL) sqlx::query( r#" DELETE FROM recon.words WHERE name = ANY($1) AND (target_id IS NOT DISTINCT FROM $2) "#, ) .bind(&names) .bind(&tid) .execute(self.pool()) .await?; Ok(()) } else { // If no target_id, delete all words matching any of the names sqlx::query("DELETE FROM recon.words WHERE name = ANY($1)") .bind(&names) .execute(self.pool()) .await?; Ok(()) } } /// Delete all given categories by their names pub async fn delete_word_categories_by_name(&self, names: Vec) -> Result { if names.is_empty() { return Err(Error::Other("No categories provided to delete".into())); } sqlx::query("DELETE FROM recon.word_categories WHERE name = ANY($1)") .bind(&names) .execute(self.pool()) .await?; Ok(()) } /// Delete all given categories by their IDs pub async fn delete_word_categories(&self, ids: Vec) -> Result { if ids.is_empty() { return Err(Error::Other("No categories provided to delete".into())); } sqlx::query("DELETE FROM recon.word_categories WHERE id = ANY($1)") .bind(&ids) .execute(self.pool()) .await?; Ok(()) } /// 
Upsert categories (conflict on name updates description & updated_at) pub async fn put_word_categories(&self, categories: Vec) -> Result { if categories.is_empty() { return Err(Error::Other("No categories provided to update".into())); } use serde_json::{Value as JsonValue, json}; let rows: Vec = categories .into_iter() .map(|c| { json!({ "id": c.id().to_string(), "name": c.name(), "description": c.description() }) }) .collect(); sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id uuid, name text, description text ) ) INSERT INTO recon.word_categories (id, name, description, updated_at) SELECT (id)::uuid, name, description, now() FROM input ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description, updated_at = now() "#, ) .bind(JsonValue::Array(rows)) .execute(self.pool()) .await?; Ok(()) } /// Upsert template keys (conflict on name updates description & updated_at) pub async fn put_template_keys(&self, template_keys: Vec) -> Result { if template_keys.is_empty() { return Err(Error::Other("No template keys provided to update".into())); } use serde_json::{Value as JsonValue, json}; let rows: Vec = template_keys .into_iter() .map(|c| { json!({ "id": c.id().to_string(), "name": c.name(), "description": c.description() }) }) .collect(); sqlx::query( r#" WITH input AS ( SELECT * FROM jsonb_to_recordset($1::jsonb) AS t( id uuid, name text, description text ) ) INSERT INTO recon.template_keys (id, name, description, updated_at) SELECT gen_random_uuid(), name, description, now() FROM input ON CONFLICT (name) DO UPDATE SET description = EXCLUDED.description, updated_at = now() "#, ) .bind(JsonValue::Array(rows)) .execute(self.pool()) .await?; Ok(()) } pub async fn delete_template_keys(&self, ids: Vec) -> Result { if ids.is_empty() { return Err(Error::Other("No template keys provided to delete".into())); } sqlx::query("DELETE FROM recon.template_keys WHERE id = ANY($1)") .bind(&ids) .execute(self.pool()) .await?; Ok(()) } /// Delete all given template keys by their names pub async fn delete_template_keys_by_name(&self, names: Vec) -> Result { if names.is_empty() { return Err(Error::Other("No template keys provided to delete".into())); } sqlx::query("DELETE FROM recon.template_keys WHERE name = ANY($1)") .bind(&names) .execute(self.pool()) .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/tasks/data.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/tasks/data.rs` ```rust use crate::recon::Ipv4Address; use crate::{DbClient, Error, Result}; use serde_json::Value; use sqlx::{FromRow, Row}; use std::collections::HashMap; use std::collections::HashSet; use std::net::Ipv4Addr; use crate::tasks::models::{TaskContainerId, TaskId}; use crate::templates::{DataKind, DataMap}; use tracing::debug; impl DbClient { /// Aggregate all string values stored under `key` from all tasks in the specified container. //// /// This is a convenience method that calls `aggregate_scalar_data` with `DataKind::SetOfStrings`. //// It will ensure the result is unique and store it under `destination_key` for the `root_container_id`. 
//// # Arguments /// * `aggregate_container_id` - The container id of the tasks to aggregate data from /// * `key` - The key under which the data to aggregate is stored in the /// tasks in the `aggregate_container_id` /// * `destination_key` - The key under which to store the aggregated data pub async fn aggregate_set_of_strings( &self, aggregate_container_id: &TaskContainerId, key: &str, destination_key: &str, ) -> Result { self.aggregate_scalar_data( aggregate_container_id, key, destination_key, DataKind::SetOfStrings(HashSet::new()), ) .await } /// Aggregate all DNS records stored under `key` from all tasks in the specified container. /// /// This method will fetch all DNS record data from the tasks in the specified container, /// merge them into a single map, and store the aggregated result under `destination_key` /// /// # Arguments /// * `aggregate_container_id` - The container id of the tasks to aggregate data from /// * `key` - The key under which the data to aggregate is stored in the /// tasks in the `aggregate_container_id` /// * `destination_key` - The key under which to store the aggregated data pub async fn aggregate_map_of_dns_records( &self, aggregate_container_id: &TaskContainerId, key: &str, destination_key: &str, ) -> Result { let mut tx = self.pool().begin().await?; let root_container_id: TaskContainerId = sqlx::query_scalar( r#" WITH RECURSIVE root_cte AS ( SELECT id, parent_id FROM task_manager.task_containers WHERE id = $1 UNION ALL SELECT tc.id, tc.parent_id FROM task_manager.task_containers tc INNER JOIN root_cte r ON tc.id = r.parent_id ) SELECT id FROM root_cte WHERE parent_id IS NULL LIMIT 1 "#, ) .bind(aggregate_container_id) .fetch_one(&mut *tx) .await?; // We perform the merge fully inside PostgreSQL to minimize round trips. // Strategy (single query after root resolution): // 1. Collect existing aggregated value (if any) + all task_data values for the provided key. // 2. Expand JSON structure: domain -> record_type -> individual record_value. // 3. Aggregate DISTINCT record_value per (domain, record_type) producing arrays. // 4. Reconstruct nested JSON object: { domain: { record_type: [values...] } }. // 5. Upsert aggregated map into task_collection_data with the MapOfDnsRecords kind. // 6. Return number of distinct domains aggregated. 
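        // Illustrative sketch of the merge described above (the domains and record values are
        // hypothetical examples, not taken from this codebase):
        //   task A value: {"example.com": {"A": ["192.0.2.1"]}}
        //   task B value: {"example.com": {"A": ["192.0.2.1", "192.0.2.2"], "MX": ["mx.example.com"]}}
        // After expanding per (domain, record_type) and aggregating DISTINCT record values,
        // the upserted collection value would be:
        //   {"example.com": {"A": ["192.0.2.1", "192.0.2.2"], "MX": ["mx.example.com"]}}
        // and the query would report domain_count = 1.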
let data_kind_name = DataKind::MapOfDnsRecords(HashMap::new()).to_string(); let row = sqlx::query( r#" WITH src AS ( SELECT value FROM task_manager.task_collection_data WHERE id = $3 AND key = $4 UNION ALL SELECT td.value FROM task_manager.task_data td WHERE td.key = $2 AND td.id IN ( SELECT t.id FROM task_manager.tasks t WHERE t.task_container_id = $1 ) ), domains AS ( SELECT (jsonb_each(src.value)).key AS domain, (jsonb_each(src.value)).value AS records_obj FROM src ), flat AS ( SELECT domain, (jsonb_each(records_obj)).key AS record_type, jsonb_array_elements_text((jsonb_each(records_obj)).value) AS record_value FROM domains ), agg AS ( SELECT domain, record_type, array_agg(DISTINCT record_value) AS values FROM flat GROUP BY domain, record_type ), nested AS ( SELECT domain, jsonb_object_agg(record_type, to_jsonb(values)) AS records_obj FROM agg GROUP BY domain ), final AS ( SELECT COALESCE(jsonb_object_agg(domain, records_obj), '{}'::jsonb) AS value, COUNT(DISTINCT domain) AS domain_count FROM nested ), upserted AS ( INSERT INTO task_manager.task_collection_data (id, key, value, kind) SELECT $3, $4, final.value, $5::data_kind FROM final ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind RETURNING value ) SELECT domain_count::BIGINT FROM final "#, ) .bind(aggregate_container_id) // $1 .bind(key) // $2 .bind(root_container_id) // $3 .bind(destination_key) // $4 .bind(&data_kind_name) // $5 .fetch_one(&mut *tx) .await?; let map_length: i64 = row.try_get("domain_count")?; // alias domain_count::BIGINT debug!( root_container_id = %root_container_id, aggregate_container_id = %aggregate_container_id, key = %key, destination_key = %destination_key, data_kind = %data_kind_name, domain_count = %map_length, "Aggregated DNS records into task collection" ); tx.commit().await?; Ok(map_length) } /// Aggregate a set of IPv4 addresses stored under `key` from all tasks in the specified container. /// /// When duplicate IPv4Addresses are found, we will merge all the existing tcp and udp ports for that IP address and store /// it as a single IPv4Address entry in the aggregated result. /// /// The end result is the destination_key will contain a SetOfIpv4Addresses where each Ipv4Address has all unique ports merged. /// /// # Arguments /// * `aggregate_container_id` - The container id of the tasks to aggregate data from /// * `key` - The key under which the data to aggregate is stored in the /// tasks in the `aggregate_container_id` /// * `destination_key` - The key under which to store the aggregated data. This key might already exist, in which case it will be updated. /// /// # Returns the number of unique IPv4 addresses after aggregation. pub async fn aggregate_set_of_ipv4_addresses( &self, aggregate_container_id: &TaskContainerId, key: &str, destination_key: &str, ) -> Result { // Strategy (DB round trips kept minimal): // 1. Begin TX // 2. Fetch root_container_id (same recursive CTE used elsewhere) // 3. Fetch (a) existing aggregated destination value (if any) and (b) all task data values for the key // in ONE query using UNION ALL. // 4. Merge in Rust, grouping by IP and unioning tcp/udp port sets (HashSet), keeping first target_id. // 5. Upsert aggregated result back into task_collection_data. // 6. Commit TX and return unique IP count. 
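        // Illustrative sketch of step 4 (hypothetical values; the union behaviour comes from
        // Ipv4Address::merge_from / merge_into_map in recon/ipv4address.rs):
        //   input entries: {"ip": "10.0.0.1", "target_id": "t1", "tcp_ports": [80]}
        //                  {"ip": "10.0.0.1", "target_id": "t1", "tcp_ports": [443], "udp_ports": [53]}
        //   merged entry:  {"ip": "10.0.0.1", "target_id": "t1", "tcp_ports": [80, 443], "udp_ports": [53]}
        // Two distinct non-empty target_ids for the same IP would surface as a ModelMergeError instead.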
let mut tx = self.pool().begin().await?; // Step 2: root container id resolution let root_container_id: TaskContainerId = sqlx::query_scalar( r#" WITH RECURSIVE root_cte AS ( SELECT id, parent_id FROM task_manager.task_containers WHERE id = $1 UNION ALL SELECT tc.id, tc.parent_id FROM task_manager.task_containers tc INNER JOIN root_cte r ON tc.id = r.parent_id ) SELECT id FROM root_cte WHERE parent_id IS NULL LIMIT 1 "#, ) .bind(aggregate_container_id) .fetch_one(&mut *tx) .await?; // Step 3: Fetch existing aggregated value + all task data values for the provided key. // We only need the raw JSON (value) column. let rows = sqlx::query( r#" SELECT value FROM ( SELECT value, 0 AS priority FROM task_manager.task_collection_data WHERE id = $1 AND key = $2 UNION ALL SELECT td.value, 1 AS priority FROM task_manager.task_data td WHERE td.key = $4 AND td.id IN ( SELECT t.id FROM task_manager.tasks t WHERE t.task_container_id = $3 ) ) s ORDER BY priority "#, ) .bind(root_container_id) .bind(destination_key) .bind(aggregate_container_id) .bind(key) .fetch_all(&mut *tx) .await?; debug!( rows_len = %rows.len(), root_container_id = %root_container_id, aggregate_container_id = %aggregate_container_id, "Fetched rows for aggregation of IPv4 addresses" ); // Step 4: Merge in Rust leveraging model merge logic. let mut agg: HashMap = HashMap::new(); for row in rows { let v: Value = row.try_get("value")?; match v { Value::Array(arr) => { for elem in arr.into_iter() { if elem.is_object() { match serde_json::from_value::(elem) { Ok(addr) => { if let Err(Error::ModelMergeError(msg)) = Ipv4Address::merge_into_map(&mut agg, addr) { return Err(Error::Other(msg)); } } Err(e) => { debug!(error = %e, "Failed to deserialize IPv4Address entry during aggregation"); } } } } } Value::Object(_) => match serde_json::from_value::(v) { Ok(addr) => { if let Err(Error::ModelMergeError(msg)) = Ipv4Address::merge_into_map(&mut agg, addr) { return Err(Error::Other(msg)); } } Err(e) => { debug!(error = %e, "Failed to deserialize single IPv4Address during aggregation"); } }, _ => { debug!("Ignoring unexpected JSON value kind during IPv4 aggregation"); } } } let unique_count = agg.len() as i64; let final_set: HashSet = agg.into_values().collect(); // Step 5: Upsert aggregated result. We only store the raw data array (see put_collection_data). let data_kind = DataKind::SetOfIpv4Addresses(final_set); let data_kind_name = data_kind.to_string(); let mut json_data = serde_json::to_value(data_kind) .map_err(|e| Error::Other(format!("Failed to serialize aggregated IPv4 set: {e}")))?; sqlx::query( r#" INSERT INTO task_manager.task_collection_data (id, key, value, kind) VALUES ($1, $2, $3, $4::data_kind) ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind "#, ) .bind(root_container_id) .bind(destination_key) .bind(json_data["data"].take()) .bind(&data_kind_name) .execute(&mut *tx) .await?; tx.commit().await?; Ok(unique_count) } /// Aggregate all scalar values stored under `key` from all tasks in the specified container. /// /// This method will work for data kinds that represent scalar sets, such as: /// SetOfStrings(HashSet) /// SetOfIntegers(HashSet) /// SetOfFqdns(HashSet) /// SetOfPorts(HashSet) /// /// It will ensure the result is unique and store it under `destination_key` for the `root_container_id`. /// TBD, do we want to remove the data thats aggregated from the source tasks? /// /// When we are dealing with simple scalar sets, we can do the aggregation directly in SQL. 
For more complex /// types, we will fetch the data and aggregate in Rust. /// /// # Arguments /// * `aggregate_container_id` - The container id of the tasks to aggregate data from /// * `key` - The key under which the data to aggregate is stored in the /// tasks in the `aggregate_container_id` /// * `destination_key` - The key under which to store the aggregated data /// * `data_kind` - The kind of data being aggregated (e.g., SetOfStrings, SetOfIntegers) pub async fn aggregate_scalar_data( &self, aggregate_container_id: &TaskContainerId, key: &str, destination_key: &str, data_kind: DataKind, // TODO: This is a bit sub-optimal as we don't want any data, just the variant. ) -> Result { let mut tx = self.pool().begin().await?; // TODO: Can we combine this into a single query with the insert? // Retrieve the root container id for the given aggregate_container_id let root_container_id: TaskContainerId = sqlx::query_scalar( r#" WITH RECURSIVE root_cte AS ( SELECT id, parent_id FROM task_manager.task_containers WHERE id = $1 UNION ALL SELECT tc.id, tc.parent_id FROM task_manager.task_containers tc INNER JOIN root_cte r ON tc.id = r.parent_id ) SELECT id FROM root_cte WHERE parent_id IS NULL LIMIT 1 "#, ) .bind(aggregate_container_id) .fetch_one(&mut *tx) .await?; // Insert or update the aggregated data in the task_collection_data table and return the set length let row = sqlx::query( r#" WITH upserted AS ( INSERT INTO task_manager.task_collection_data (id, key, value, kind) SELECT $3, $4, CASE WHEN EXISTS ( SELECT 1 FROM task_manager.task_collection_data WHERE id = $3 AND key = $4 ) THEN ( -- If there's existing data for the destination key, include it in the aggregation SELECT jsonb_agg(DISTINCT elem) FROM ( SELECT jsonb_array_elements(value) AS elem FROM task_manager.task_collection_data WHERE id = $3 AND key = $4 UNION SELECT jsonb_array_elements(value) AS elem -- Changed: Expand task_data arrays here FROM task_manager.task_data WHERE key = $2 AND id IN ( SELECT id FROM tasks WHERE task_container_id = $1 ) ) AS all_values ) ELSE ( -- If the destination key does not exist yet, just aggregate from task_data SELECT jsonb_agg(DISTINCT elem) -- Changed: Aggregate expanded elements, not whole values FROM ( SELECT jsonb_array_elements(value) AS elem -- Changed: Expand task_data arrays here FROM task_manager.task_data WHERE key = $2 AND id IN ( SELECT id FROM task_manager.tasks WHERE task_container_id = $1 ) ) AS all_values ) END AS value, $5::data_kind AS kind ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value RETURNING value ) SELECT COALESCE(jsonb_array_length(value), 0)::BIGINT AS set_length FROM upserted "#, ) .bind(aggregate_container_id) .bind(key) .bind(root_container_id) .bind(destination_key) .bind(data_kind.to_string()) .fetch_one(&mut *tx) .await?; let set_length: i64 = row.try_get("set_length")?; tx.commit().await?; Ok(set_length) } /// Fetch data stored under `key` for the specified `root_container_id`. 
pub async fn get_data( &self, root_container_id: &TaskContainerId, key: &str, ) -> Result { // Query the task_collection_data table for the given root_container_id and key // // NOTE: we cast kind as text so we can read it as a string and convert to DataKind in our DataKind.from_row() method let row = sqlx::query( r#"SELECT value, kind::text AS kind FROM task_manager.task_collection_data WHERE id = $1 AND key = $2"#, ) .bind(root_container_id) .bind(key) .fetch_optional(self.pool()) .await?; let row = match row { Some(row) => row, None => return Err(Error::Other("No data found".into())), }; Ok(DataKind::from_row(&row)?) } pub async fn get_data_map( &self, root_container_id: &TaskContainerId, keys: HashSet, _keys_with_path: HashSet<(String, String)>, ) -> Result { // Query for all keys in the provided list debug!(root_container_id = %root_container_id, keys = ?keys, "Fetching data map for task collection"); let rows = sqlx::query( r#"SELECT key, value, kind::text AS kind FROM task_manager.task_collection_data WHERE id = $1 AND key = ANY($2)"#, ) .bind(root_container_id) .bind(keys.iter().cloned().collect::>()) .fetch_all(self.pool()) .await?; let mut map = DataMap::new(); for row in rows { let key: String = row.try_get("key")?; map.insert(key, DataKind::from_row(&row)?); } // Ensure all requested keys are present in the result if keys.len() != map.len() { let fetched_keys: std::collections::HashSet<&str> = map.keys().map(|k| k.as_str()).collect(); let missing_keys: Vec = keys .iter() .filter(|k| !fetched_keys.contains(k.as_str())) .cloned() .collect(); return Err(Error::Other(format!( "Missing key(s) in data map: {:?}", missing_keys ))); } Ok(map) } /// Store or replace data under `key` for the task collection. /// /// # Arguments /// * `root_container_id` - The root container id of the task collection to store /// data for /// * `key` - The key under which to store the data /// * `data` - The data to store. Since the data can be quite large, we take ownership of it. pub async fn put_collection_data( &self, root_container_id: &TaskContainerId, key: &str, data: DataKind, ) -> Result { let data_kind_name = data.to_string(); let mut json_data = serde_json::to_value(data).map_err(|e| { Error::Other(format!( "Failed to serialize DataKind to JSON for key '{}': {}", key, e )) })?; debug!(root_container_id = %root_container_id, key = %key, data_kind = %data_kind_name, "Updating task collectiondata"); sqlx::query( r#" INSERT INTO task_manager.task_collection_data (id, key, value, kind) VALUES ($1, $2, $3, $4::data_kind) ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind "#, ) .bind(root_container_id) .bind(key) .bind(json_data["data"].take()) .bind(data_kind_name) .execute(self.pool()) .await?; Ok(()) } /// Store or replace data under `key` for the given task. /// /// # Arguments /// * `task_id` - The id of the task to store data for /// * `key` - The key under which to store the data /// * `data` - The data to store. Since the data can be quite large, we take ownership of it. 
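    ///
    /// Illustrative note (an assumption inferred from how the value is split below, not a schema
    /// guarantee): a `DataKind` such as a set of strings serializes to an envelope along the lines of
    /// `{"kind": "SetOfStrings", "data": ["a", "b"]}`; only the `data` part is stored in the `value`
    /// column, while the kind name is written to the `kind` column.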
pub async fn put_task_data(&self, task_id: &TaskId, key: &str, data: DataKind) -> Result { let data_kind_name = data.to_string(); let mut json_data = serde_json::to_value(data).map_err(|e| { Error::Other(format!( "Failed to serialize DataKind to JSON for key '{}': {}", key, e )) })?; debug!(task_id = %task_id, key = %key, data_kind = %data_kind_name, "Updating task data"); sqlx::query( r#" INSERT INTO task_manager.task_data (id, key, value, kind) VALUES ($1, $2, $3, $4::data_kind) ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind "#, ) .bind(task_id) .bind(key) .bind(json_data["data"].take()) .bind(data_kind_name) .execute(self.pool()) .await?; Ok(()) } /// Store or replace multiple data entries for the task collection. /// /// # Arguments /// * `root_container_id` - The root container id of the task collection to store /// data for /// * `data_map` - The map of key-value pairs to store. Since the data can be quite large, we take ownership of it. pub async fn put_collection_data_map( &self, root_container_id: &TaskContainerId, data_map: DataMap, ) -> Result { let mut tx = self.pool().begin().await?; // Batch insert/replace using a single query for all entries in data_map // Build the VALUES clause dynamically if !data_map.is_empty() { // Prepare vectors for keys, values, and kinds let mut keys = Vec::with_capacity(data_map.len()); let mut values = Vec::with_capacity(data_map.len()); let mut kinds = Vec::with_capacity(data_map.len()); for (key, data) in data_map { let data_kind_name = data.to_string(); let mut json_data = serde_json::to_value(data).map_err(|e| { Error::Other(format!( "Failed to serialize DataKind to JSON for key '{}': {}", key, e )) })?; keys.push(key.clone()); values.push(json_data["data"].take()); kinds.push(data_kind_name); } // Use UNNEST to insert all rows at once sqlx::query( r#" INSERT INTO task_manager.task_collection_data (id, key, value, kind) SELECT $1, key, value, kind FROM UNNEST($2::text[], $3::jsonb[], $4::data_kind[]) AS t(key, value, kind) ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind "#, ) .bind(root_container_id) .bind(&keys) .bind(&values) .bind(&kinds) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(()) } /// Store or replace multiple data entries for the task. /// /// # Arguments /// * `task_id` - The id of the task to store data for /// * `data_map` - The map of key-value pairs to store. Since the data can be quite large, we take ownership of it. 
pub async fn put_task_data_map(&self, task_id: &TaskId, data_map: DataMap) -> Result { let mut tx = self.pool().begin().await?; // Batch insert/replace using a single query for all entries in data_map // Build the VALUES clause dynamically if !data_map.is_empty() { // Prepare vectors for keys, values, and kinds let mut keys = Vec::with_capacity(data_map.len()); let mut values = Vec::with_capacity(data_map.len()); let mut kinds = Vec::with_capacity(data_map.len()); for (key, data) in data_map { let data_kind_name = data.to_string(); let mut json_data = serde_json::to_value(data).map_err(|e| { Error::Other(format!( "Failed to serialize DataKind to JSON for key '{}': {}", key, e )) })?; keys.push(key.clone()); values.push(json_data["data"].take()); kinds.push(data_kind_name); } // Use UNNEST to insert all rows at once sqlx::query( r#" INSERT INTO task_manager.task_data (id, key, value, kind) SELECT $1, key, value, kind FROM UNNEST($2::text[], $3::jsonb[], $4::data_kind[]) AS t(key, value, kind) ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind "#, ) .bind(task_id) .bind(&keys) .bind(&values) .bind(&kinds) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(()) } } ``` ### `shared/mrpf_core/src/db/tasks/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/tasks/mod.rs` ```rust mod data; mod tasks; ``` ### `shared/mrpf_core/src/db/tasks/tasks.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/db/tasks/tasks.rs` ```rust use crate::recon::JobState; use crate::tasks::models::{ Task, TaskCollection, TaskContainer, TaskContainerChild, TaskContainerId, TaskId, TaskState, }; use crate::templates::DataMap; use crate::{DbClient, Error, Result}; use sqlx::types::chrono::Utc; use sqlx::{FromRow, Row}; use std::collections::HashMap; use std::collections::HashSet; use tracing::{debug, error}; use uuid::Uuid; impl DbClient { /// List all root container tasks. pub async fn list_root_containers(&self) -> Result> { let rows = sqlx::query("SELECT * FROM task_manager.task_containers WHERE parent_id IS NULL") .fetch_all(self.pool()) .await?; let mut containers = Vec::with_capacity(rows.len()); for row in rows { let container = TaskContainer::from_row(&row)?; containers.push(container); } Ok(containers) } /// Retrieve a full TaskCollection starting from the specified root container ID. /// /// This function recursively retrieves all task containers and associated tasks based on the /// provided task_container_id, including tasks in nested containers and the containers /// themselves. /// /// # Arguments /// * `task_container_id` - The ID of the task container to list child tasks for. /// /// Returns a vector of tuples, each containing a TaskContainer and its directly associated tasks. pub async fn get_task_collection( &self, root_container_id: &TaskContainerId, ) -> Result { // TODO: Investigate ltree extension, this could potentially make our recursive queries simpler/faster. let mut tx = self.pool().begin().await?; // Ensure a consistent snapshot for the duration of this transaction sqlx::query("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY") .execute(&mut *tx) .await?; // 1) Recursively collect the subtree of containers (root + all descendants). // We grab full rows so TaskContainer::from_row can be used, but we also // read parent_id and execution_order from the row for ordering/assembly. 
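        // Sketch of how the recursive CTE below expands (container ids are hypothetical):
        //   iteration 0: ct = { root }
        //   iteration 1: ct += containers whose parent_id = root
        //   iteration 2: ct += containers whose parent_id was added in iteration 1
        //   ... and so on until no child containers remain, yielding the full subtree in one query.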
debug!(root_container_id = %root_container_id, "Fetching container subtree..."); let container_rows = sqlx::query( r#" WITH RECURSIVE ct AS ( SELECT id FROM task_manager.task_containers WHERE id = $1 UNION ALL SELECT c.id FROM task_manager.task_containers c JOIN ct ON c.parent_id = ct.id ) SELECT tc.* FROM task_manager.task_containers tc JOIN ct ON ct.id = tc.id "#, ) .bind(root_container_id) .fetch_all(&mut *tx) .await?; debug!( root_container_id = %root_container_id, container_count = %container_rows.len(), "Fetched container subtree", ); // 2) Fetch all tasks that live inside any container in that subtree. debug!(root_container_id = %root_container_id, "Fetching tasks in subtree..."); let task_rows = sqlx::query( r#" WITH RECURSIVE ct AS ( SELECT id FROM task_manager.task_containers WHERE id = $1 UNION ALL SELECT c.id FROM task_manager.task_containers c JOIN ct ON c.parent_id = ct.id ) SELECT t.* FROM task_manager.tasks t WHERE t.task_container_id IN (SELECT id FROM ct) "#, ) .bind(root_container_id) .fetch_all(&mut *tx) .await?; debug!( root_container_id = %root_container_id, task_count = %task_rows.len(), "Fetched tasks in subtree", ); tx.commit().await?; // --- Build indices for fast assembly --- let mut containers_by_id = HashMap::with_capacity(container_rows.len()); let mut child_containers_by_parent: HashMap, Uuid)>> = HashMap::new(); // We read parent_id / execution_order directly from the row to guide assembly, // not to store them in the structs. debug!(root_container_id = %root_container_id, "Serializing containers from database..."); for row in &container_rows { let container = TaskContainer::from_row(row)?; let id: Uuid = row.try_get("id")?; let parent_id: Option = row.try_get("parent_id")?; let exec_order: Option = row.try_get("execution_order")?; containers_by_id.insert(id, container); if let Some(pid) = parent_id { child_containers_by_parent .entry(pid) .or_default() .push((exec_order, id)); } } // Map tasks by their container, keeping their execution_order for sorting debug!(root_container_id = %root_container_id, "Serializing tasks from database..."); let mut tasks_by_container: HashMap, Task)>> = HashMap::new(); for row in &task_rows { let task = Task::from_row(row)?; let container_id = row.try_get("task_container_id")?; let exec_order = row.try_get("execution_order")?; tasks_by_container .entry(container_id) .or_default() .push((exec_order, task)); } // Helper: convert Option to a sortable key with NULLS LAST. // This means that any task/container with NULL order will appear at the end. #[inline] fn sort_key(order: Option) -> i32 { order.map(|v| v as i32).unwrap_or(i32::MAX) } // Recursively assemble the full tree starting at the requested root container. fn assemble( container_id: Uuid, containers_by_id: &HashMap, child_containers_by_parent: &HashMap, Uuid)>>, tasks_by_container: &HashMap, Task)>>, ) -> Result { let mut container = containers_by_id .get(&container_id) .ok_or_else(|| Error::Other("container id present in map".to_string()))? .clone(); // Gather child tasks (for this container) let mut merged = Vec::new(); if let Some(ts) = tasks_by_container.get(&container_id) { for (order, task) in ts { // Use Task::id() for deterministic tie-breaking in sort. 
let tid = *task.id(); merged.push(( sort_key(*order), tid, TaskContainerChild::Task(task.clone()), )); } } // Gather child containers (direct descendants) if let Some(cs) = child_containers_by_parent.get(&container_id) { for (order, child_id) in cs { let child = assemble( *child_id, containers_by_id, child_containers_by_parent, tasks_by_container, )?; // Use child container's id for deterministic tie-breaking // let cid = *child.parent_id().expect("child container has parent"); // Not sure If i understand the comment above correctly, but lets see if it works with the child id instead let cid = *child.id(); merged.push((sort_key(*order), cid, TaskContainerChild::Container(child))); } } // Now unify order: smaller execution_order first; NULLs at the end. // If multiple items share the same order, use the UUID as a stable tiebreaker. merged.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1))); let children = merged.into_iter().map(|(_, _, item)| item).collect(); container.replace_children(children); Ok(container) } // Build the root let root_uuid: Uuid = *root_container_id; if !containers_by_id.contains_key(&root_uuid) { return Err(Error::Other(format!( "Root container {} not found", root_uuid ))); } debug!(root_container_id = %root_container_id, "Assembling full task collection..."); let root = assemble( root_uuid, &containers_by_id, &child_containers_by_parent, &tasks_by_container, )?; TaskCollection::new(root) } /// Create a TaskCollection by extracting and inserting all containers and tasks. pub async fn create_task_collection(&self, task_collection: TaskCollection) -> Result { let mut tx = self.pool().begin().await?; sqlx::query("SET TRANSACTION ISOLATION LEVEL READ COMMITTED, READ WRITE") .execute(&mut *tx) .await?; let root = task_collection.into_root_container(); // We'll do two passes: containers first (top-down), then tasks. let mut container_stack = vec![(root, None, None)]; // (node, parent_id, exec_order) let mut containers_to_insert = Vec::new(); let mut tasks_to_insert = Vec::new(); while let Some((container, parent_id, exec_order)) = container_stack.pop() { let container_id = *container.id(); containers_to_insert.push(( container_id, parent_id, container.execution_mode(), exec_order, )); for (i, child) in container.into_children().enumerate() { let order = i as i16; match child { TaskContainerChild::Task(task) => { // TODO: can we just use Task here and get it to properly serialize into JSONB etc? tasks_to_insert.push(( *task.id(), task.state(), container_id, // Use the ID of the container we're currently inserting into task.completion_policy().clone(), Some(order), serde_json::to_value(task.definition())?, task.timeout_secs() as i32, )); } TaskContainerChild::Container(container) => { container_stack.push((container, Some(container_id), Some(order))); } } } } // Insert containers if !containers_to_insert.is_empty() { // Collect values for SQL array binding. The tuples in // `containers_to_insert` already own their Uuid and enum values, // so collect them by value (no extra cloned() adapter required). 
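            // How the UNNEST insert below works (positional zip): row i of the INSERT is
            //   (ids[i], parent_ids[i], exec_modes[i], exec_orders[i])
            // so every container in the collection is written in a single statement / round trip.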
let ids: Vec<_> = containers_to_insert.iter().map(|c| c.0).collect(); let parent_ids: Vec<_> = containers_to_insert.iter().map(|c| c.1).collect(); // TaskExecutionMode and Option are Copy-like here, collect by value let exec_modes: Vec<_> = containers_to_insert.iter().map(|c| c.2).collect(); let exec_orders: Vec<_> = containers_to_insert.iter().map(|c| c.3).collect(); sqlx::query( r#" INSERT INTO task_manager.task_containers (id, parent_id, execution_mode, execution_order) SELECT * FROM UNNEST($1::uuid[], $2::uuid[], $3::text[]::task_execution_mode[], $4::smallint[]) "#, ) .bind(&ids) .bind(&parent_ids) .bind(&exec_modes) .bind(&exec_orders) .execute(&mut *tx) .await?; } // Insert tasks if !tasks_to_insert.is_empty() { // Collect owned values for tasks insertion (take values from the tuples) let t_ids: Vec<_> = tasks_to_insert.iter().map(|t| t.0).collect(); let t_states: Vec<_> = tasks_to_insert.iter().map(|t| t.1).collect(); let t_container_ids: Vec<_> = tasks_to_insert.iter().map(|t| t.2).collect(); let t_policies: Vec<_> = tasks_to_insert.iter().map(|t| t.3.clone()).collect(); let t_exec_orders: Vec<_> = tasks_to_insert.iter().map(|t| t.4).collect(); let t_snapshots: Vec<_> = tasks_to_insert.iter().map(|t| t.5.clone()).collect(); let t_timeouts: Vec<_> = tasks_to_insert.iter().map(|t| t.6).collect(); sqlx::query( r#" INSERT INTO task_manager.tasks (id, state, task_container_id, completion_policy, execution_order, definition, timeout) SELECT * FROM UNNEST( $1::uuid[], $2::text[]::task_state[], $3::uuid[], $4::text[]::completion_policy[], $5::smallint[], $6::jsonb[], $7::int[] ) "#, ) .bind(&t_ids) .bind(&t_states) .bind(&t_container_ids) .bind(&t_policies) .bind(&t_exec_orders) .bind(&t_snapshots) .bind(&t_timeouts) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(()) } /// Create tasks /// /// # Arguments /// * `tasks` - A slice of tasks to create /// * `generator_task_id` - ID of the task that generated these tasks /// * `container_id` - The ID of the container to create the tasks in. If None, we will retrieve the parent container of the generator task and use that. pub async fn create_tasks( &self, tasks: &[Task], generator_id: TaskId, container_id: Option, ) -> Result { let mut tx = self.pool().begin().await?; // Collect all fields in a single pass for batch insert let mut ids = Vec::with_capacity(tasks.len()); let mut states = Vec::with_capacity(tasks.len()); let mut container_ids: Vec<_> = Vec::with_capacity(tasks.len()); let mut completion_policies = Vec::with_capacity(tasks.len()); let mut definitions = Vec::with_capacity(tasks.len()); // If container_id is None, look up the parent container of the generator task let container_id = match container_id { Some(id) => id, None => { debug!(generator_id = %generator_id, "target container_id not provided, looking up Generator parent container ID"); sqlx::query_scalar::<_, TaskContainerId>( "SELECT task_container_id FROM task_manager.tasks WHERE id = $1", ) .bind(generator_id) .fetch_one(&mut *tx) .await? } }; // Insert tasks for task in tasks { debug!(task_id = %task.id(), container_id = %container_id, kind = %task.definition(), "Preparing task for insertion"); ids.push(*task.id()); match task.state() { TaskState::Pending | TaskState::Disabled => states.push(task.state()), other => { return Err(Error::Other(format!( "Invalid task state for batch insert: {:?}. 
Only Pending or Disabled allowed.", other ))); } } container_ids.push(container_id); // clone CompletionPolicy to own it for batch insert completion_policies.push(task.completion_policy().clone()); definitions.push(serde_json::to_value(task.definition())?); // timeout copied from runtime task (associated const of definition) // We'll bind this as an array of int (seconds) } debug!(task_count = %ids.len(), container_id = %container_id, generator_id = %generator_id, "Inserting tasks"); // Convert enum vectors to text representations for Postgres enum array binding // let states_text: Vec<_> = states.iter().map(|s| s.to_string()).collect(); // let policies_text: Vec = // completion_policies.iter().map(|p| p.to_string()).collect(); sqlx::query( "INSERT INTO task_manager.tasks (id, state, task_container_id, completion_policy, definition, timeout) SELECT * FROM UNNEST($1::uuid[], $2::text[]::task_state[], $3::uuid[], $4::text[]::completion_policy[], $5::jsonb[], $6::int[])", ) .bind(&ids) .bind(&states) .bind(&container_ids) .bind(&completion_policies) .bind(&definitions) .bind(tasks.iter().map(|t| t.timeout_secs() as i32).collect::>()) .execute(&mut *tx) .await?; // Update generator_task_id task created_task_ids with new IDs if !ids.is_empty() { debug!(generator_task_id = %generator_id, created_task_count = %ids.len(), "Updating generator task with created task IDs"); let row = sqlx::query( r#" UPDATE task_manager.tasks SET created_task_ids = ( SELECT array( SELECT DISTINCT id FROM unnest(created_task_ids || $1::uuid[]) AS id ) ) WHERE id = $2 RETURNING array_length(created_task_ids, 1) AS created_count, expected_task_count "#, ) .bind(&ids) .bind(generator_id) .fetch_one(&mut *tx) .await?; let created_count: Option = row.try_get("created_count")?; let expected_task_count: Option = row.try_get("expected_task_count")?; debug!( generator_task_id = %generator_id, created_count, expected_task_count, "Updated generator task with created task IDs" ); } tx.commit().await?; Ok(()) } /// Finalize given task collections and store job history results in recon database. /// /// # Arguments /// * `task_collections` - A map of TaskContainerId to TaskState representing the root container id and the final state to set pub async fn finalize_task_collections( &self, task_collections: HashMap, ) -> Result { // Nothing to do if empty. if task_collections.is_empty() { return Ok(()); } let mut tx = self.pool().begin().await?; let completed_at = Utc::now(); // Collect root container ids & target states let root_ids: Vec = task_collections.keys().cloned().collect(); // 1) Delete all root containers in one statement (CASCADE via FK definitions will remove // descendant containers, tasks, task_data & task_collection_data). sqlx::query("DELETE FROM task_manager.task_containers WHERE id = ANY($1::uuid[])") .bind(&root_ids) .execute(&mut *tx) .await?; // 2) Bulk update latest job_history rows for each job to final state & completion time. // Build arrays of mapped JobState for bulk operation. 
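        // Sketch of the bulk update below: `DISTINCT ON (job_id) ... ORDER BY job_id, started_at DESC`
        // keeps exactly one history row per job (the most recently started one), while
        // `UNNEST($1::uuid[]) ... UNNEST($2::text[]::job_state[])` pairs each root container id with
        // its mapped JobState by array position, so `root_ids` and `job_states` must stay index-aligned.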
let mut job_states: Vec = Vec::with_capacity(root_ids.len()); for rid in &root_ids { // We must look up the provided final TaskState from the hashmap (clone earlier state map) if let Some(ts) = task_collections.get(rid) { let js = match ts { TaskState::Pending => JobState::Pending, TaskState::Running | TaskState::PendingTaskCreation => JobState::Running, TaskState::PartiallyFailed => JobState::PartiallyFailed, TaskState::Succeeded => JobState::Succeeded, TaskState::Failed => JobState::Failed, TaskState::Disabled => JobState::Disabled, TaskState::Timeout => JobState::Timeout, }; job_states.push(js); } else { // Should not happen; defensive job_states.push(JobState::Failed); } } // Update only the latest row per job_id (ordered by started_at DESC) // Use DISTINCT ON within a CTE to isolate latest history IDs, then update join. let updated_job_ids: HashSet = sqlx::query_scalar( r#" WITH latest AS ( SELECT DISTINCT ON (job_id) id, job_id FROM recon.job_history WHERE job_id = ANY($1::uuid[]) ORDER BY job_id, started_at DESC ), new_states AS ( SELECT UNNEST($1::uuid[]) AS job_id, UNNEST($2::text[]::job_state[]) AS state ) UPDATE recon.job_history h SET state = new_states.state, completed_at = $3 FROM latest l JOIN new_states ON new_states.job_id = l.job_id WHERE h.id = l.id RETURNING h.job_id "#, ) .bind(&root_ids) .bind(&job_states) .bind(completed_at) .fetch_all(&mut *tx) .await? .into_iter() .collect(); // Validate we updated a history row for every provided root container id. if updated_job_ids.len() != root_ids.len() { error!( expected_count = %root_ids.len(), actual_count = %updated_job_ids.len(), "Not all job_history rows were updated for finalized task collections"); // Identify missing IDs for clarity let missing: Vec = root_ids .into_iter() .filter(|id| !updated_job_ids.contains(id)) .collect(); return Err(Error::Other(format!( "Missing job_history rows for finalized task collections: {:?}", missing ))); } tx.commit().await?; Ok(()) } /// Update the state of an existing task. /// /// # Arguments /// * `task_id` - The ID of the task to update /// * `state` - The new state to set for the task /// * `data` - Optional JSON data to associate with the task /// * `expected_task_count` - Optional expected task count to update. This is only applicable for tasks implementing the GeneratorTask trait. 
pub async fn update_task_state( &self, task_id: &TaskId, state: &TaskState, data: Option, expected_task_count: Option, ) -> Result { let mut tx = self.pool().begin().await?; // Capture a single timestamp for any started tasks (if transitioning to Running) let now = Utc::now(); // Find existing task and lock the row for update debug!(task_id = %task_id, "Validating task state before updating"); let existing_state: TaskState = sqlx::query_scalar( r#"SELECT state as "state: TaskState" FROM task_manager.tasks WHERE id = $1 FOR UPDATE"#, ) .bind(task_id) .fetch_one(&mut *tx) .await?; // Ensure the existing task state is not in a terminal state if matches!( existing_state, TaskState::Succeeded | TaskState::Failed | TaskState::PartiallyFailed ) { return Err("Cannot update task: already in a terminal state".into()); } // Update the task if let Some(expected_task_count) = expected_task_count { // Update the task state and expected_task_count debug!(task_id = %task_id, task_state = %state, expected_task_count = %expected_task_count, "Updating task state and expected_task_count"); if matches!(state, TaskState::Running) { sqlx::query( "UPDATE task_manager.tasks SET state = $1, expected_task_count = $2, started_at = COALESCE(started_at, $4) WHERE id = $3", ) .bind(state) .bind(expected_task_count) .bind(task_id) .bind(now) .execute(&mut *tx) .await?; } else { sqlx::query( "UPDATE task_manager.tasks SET state = $1, expected_task_count = $2 WHERE id = $3", ) .bind(state) .bind(expected_task_count) .bind(task_id) .execute(&mut *tx) .await?; } } else { // Update the task state debug!(task_id = %task_id, task_state = %state, "Updating task state"); if matches!(state, TaskState::Running) { sqlx::query("UPDATE task_manager.tasks SET state = $1, started_at = COALESCE(started_at, $3) WHERE id = $2") .bind(state) .bind(task_id) .bind(now) .execute(&mut *tx) .await?; } else { sqlx::query("UPDATE task_manager.tasks SET state = $1 WHERE id = $2") .bind(state) .bind(task_id) .execute(&mut *tx) .await?; } } // Insert or replace task data if provided if let Some(data) = data { let mut keys: Vec = Vec::with_capacity(data.len()); let mut values = Vec::with_capacity(data.len()); let mut kind: Vec = Vec::with_capacity(data.len()); for (key, value) in data.iter() { keys.push(key.clone()); let mut json_value = serde_json::to_value(value)?; values.push(json_value["data"].take()); kind.push( json_value["kind"] .as_str() .ok_or_else(|| Error::Other("data kind should be a string".to_string()))? .to_string(), ); } debug!(task_id = %task_id, data_keys = ?keys, data_values = ?values, data_kinds = ?kind, "Updating task data"); sqlx::query( r#" INSERT INTO task_manager.task_data (id, key, value, kind) SELECT $1, key, value, kind FROM UNNEST($2::text[], $3::jsonb[], $4::data_kind[]) AS t(key, value, kind) ON CONFLICT (id, key) DO UPDATE SET value = EXCLUDED.value, kind = EXCLUDED.kind "#, ) .bind(task_id) .bind(&keys) .bind(&values) .bind(&kind) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(()) } /// Update multiple tasks to the Running state in a single transaction. /// All tasks must currently be in the Pending state. /// /// # Arguments /// * `task_ids` - A slice of TaskId references representing the tasks to update. 
pub async fn update_tasks_to_running(&self, task_ids: &[TaskId]) -> Result { let mut tx = self.pool().begin().await?; let now = Utc::now(); // Find existing tasks and lock the row for update debug!(task_count = %task_ids.len(), "Validating tasks state before updating to Running state"); let existing_states: Vec<(TaskId, TaskState)> = sqlx::query_as( r#"SELECT id, state as "state: TaskState" FROM task_manager.tasks WHERE id = ANY($1) FOR UPDATE"#, ) .bind(task_ids) .fetch_all(&mut *tx) .await?; // Ensure all existing task states are Pending if existing_states .iter() .any(|(_, state)| !matches!(state, TaskState::Pending)) { return Err("Cannot update tasks: all tasks must be in Pending state".into()); } // Update the state of all tasks to Running debug!(task_count = %task_ids.len(), "Updating tasks to Running state"); sqlx::query("UPDATE task_manager.tasks SET state = $1::task_state, started_at = COALESCE(started_at, $3) WHERE id = ANY($2)") .bind(TaskState::Running) .bind(task_ids) .bind(now) .execute(&mut *tx) .await?; tx.commit().await?; Ok(()) } /// Retrieve all tasks in the PendingTaskCreation state and update to Succeeded if /// the created_tasks_count matches expected_task_count. /// /// Returns the set of root container IDs for all tasks that were updated to Succeeded. /// This allows the caller to check these containers for new tasks to start. pub async fn complete_generator_tasks(&self) -> Result> { let mut tx = self.pool().begin().await?; debug!("Completing PendingTaskCreation generator tasks and returning root container IDs"); let rows = sqlx::query( r#" WITH updated AS ( UPDATE task_manager.tasks SET state = $1 WHERE state = $2 AND array_length(created_task_ids, 1) = expected_task_count RETURNING id, task_container_id ), roots AS ( SELECT u.id, -- Recursive CTE to walk up to the root container ( WITH RECURSIVE ct AS ( SELECT tc.id, tc.parent_id FROM task_manager.task_containers tc WHERE tc.id = u.task_container_id UNION ALL SELECT tc2.id, tc2.parent_id FROM task_manager.task_containers tc2 JOIN ct ON tc2.id = ct.parent_id ) SELECT id FROM ct WHERE parent_id IS NULL LIMIT 1 ) AS root_container_id FROM updated u ) SELECT DISTINCT root_container_id FROM roots WHERE root_container_id IS NOT NULL "#, ) .bind(TaskState::Succeeded) .bind(TaskState::PendingTaskCreation) .fetch_all(&mut *tx) .await?; tx.commit().await?; // Collect root_container_ids let mut root_ids = HashSet::new(); for row in rows { let root_id: Uuid = row.try_get("root_container_id")?; root_ids.insert(root_id); } Ok(root_ids) } /// Cancel all tasks that have timed out, setting their state to Timeout. /// /// Returns a unique set of root container IDs for all tasks that were cancelled due to timeout. 
pub async fn cancel_timed_out_tasks(&self) -> Result> { let mut tx = self.pool().begin().await?; debug!("Cancelling timed-out tasks"); // Update tasks whose timeout has expired and return their root container ids let rows = sqlx::query( r#" WITH updated AS ( UPDATE task_manager.tasks SET state = $1 WHERE state IN ($2, $3) AND started_at IS NOT NULL AND (started_at + make_interval(secs => timeout)) < now() RETURNING id, task_container_id ) SELECT DISTINCT ( WITH RECURSIVE ct AS ( SELECT tc.id, tc.parent_id FROM task_manager.task_containers tc WHERE tc.id = u.task_container_id UNION ALL SELECT tc2.id, tc2.parent_id FROM task_manager.task_containers tc2 JOIN ct ON tc2.id = ct.parent_id ) SELECT id FROM ct WHERE parent_id IS NULL LIMIT 1 ) AS root_container_id FROM updated u WHERE u.task_container_id IS NOT NULL "#, ) .bind(TaskState::Timeout) .bind(TaskState::Running) .bind(TaskState::PendingTaskCreation) .fetch_all(&mut *tx) .await?; let mut root_container_ids_with_cancelled_tasks = HashSet::new(); for row in rows { let root_id: Option = row.try_get("root_container_id")?; if let Some(rid) = root_id { root_container_ids_with_cancelled_tasks.insert(rid); } } tx.commit().await?; Ok(root_container_ids_with_cancelled_tasks) } } ``` ### `shared/mrpf_core/src/error/base.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/error/base.rs` ```rust use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; pub type Result = std::result::Result; /// Error type for the `mrpf_models` crate. /// /// Keep this lightweight; model code is often used in performance-sensitive contexts. /// Additional variants can be added as other models introduce error conditions. #[derive(Debug)] pub enum Error { #[cfg(feature = "sqlx")] Database(sqlx::Error), /// Errors related to merging model instances (store a compact human-readable description) ModelMergeError(String), /// Errors related to invalid execution context InvalidExecutionContext(String), // Runtime Expression parsing errors UnexpectedChar(char), UnclosedBrace, InvalidIdentifier, UnbalancedParen, EmptyExpression, UnclosedString, // Renderer function errors InvalidArguments, /// Other errors Other(String), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { #[cfg(feature = "sqlx")] Error::Database(err) => write!(f, "Database error: {}", err), Error::ModelMergeError(msg) => write!(f, "{msg}"), Error::InvalidExecutionContext(msg) => write!(f, "Invalid execution context: {msg}"), // Runtime Expression parsing errors Error::UnexpectedChar(c) => write!(f, "Unexpected character: {}", c), Error::UnclosedBrace => write!(f, "Unclosed brace"), Error::InvalidIdentifier => write!(f, "Invalid identifier"), Error::UnbalancedParen => write!(f, "Unbalanced parentheses"), Error::EmptyExpression => write!(f, "Empty expression"), Error::UnclosedString => write!(f, "Unclosed string"), // Renderer function errors Error::InvalidArguments => write!(f, "Invalid arguments"), Error::Other(msg) => write!(f, "{msg}"), } } } impl StdError for Error {} impl From for Error { fn from(e: IoError) -> Self { Error::Other(format!("IO error: {}", e)) } } impl From for Error { fn from(e: std::env::VarError) -> Self { Error::Other(format!("Environment variable error: {}", e)) } } impl From for Error { fn from(e: serde_json::Error) -> Self { Error::Other(format!("JSON error: {}", e)) } } impl From<&str> for Error { fn from(s: &str) -> Self { Error::Other(s.to_owned()) } } // sqlx feature conversions #[cfg(feature = "sqlx")] 
mod sqlx_feature { use super::Error; use sqlx::Error as SqlxError; impl From for Error { fn from(e: SqlxError) -> Self { Error::Database(e) } } impl From for sqlx::Error { fn from(e: Error) -> Self { match e { Error::Database(db_err) => db_err, Error::ModelMergeError(msg) => sqlx::Error::ColumnDecode { index: "model_merge".into(), source: Box::new(Error::ModelMergeError(msg)), }, Error::InvalidExecutionContext(msg) => sqlx::Error::ColumnDecode { index: "invalid_execution_context".into(), source: Box::new(Error::InvalidExecutionContext(msg)), }, // Catch any other type of Error and convert to a generic sqlx::Error _ => sqlx::Error::ColumnDecode { index: "other".into(), source: Box::new(e), }, } } } } // SQS feature conversions #[cfg(any(feature = "task_manager", feature = "task_execution"))] mod sqs_feature { use super::Error; use aws_sdk_sqs::error::SdkError; impl From> for Error { fn from(e: SdkError) -> Self { Error::Other(format!("SQS SendMessage error: {e}")) } } impl From> for Error { fn from(e: SdkError) -> Self { Error::Other(format!("SQS ReceiveMessage error: {e}")) } } } ``` ### `shared/mrpf_core/src/error/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/error/mod.rs` ```rust mod base; pub use base::{Error, Result}; ``` ### `shared/mrpf_core/src/lib.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/lib.rs` ```rust // Error module for MRPF models. mod error; pub use error::{Error, Result}; // The recon module contains models that can be interacted with directly through the MRPF API. pub mod recon; mod templates; pub use templates::{Template, Templates}; // The task_manager module contains models related to the task management service. // Note that we decouple this from the recon Job/JobHistory models intentionally to ensure the task_manager // stays as independent as possible from the recon service. #[cfg(feature = "task_shared")] pub mod tasks; #[cfg(feature = "sqlx")] pub mod db; #[cfg(feature = "sqlx")] pub use db::DbClient; ``` ### `shared/mrpf_core/src/recon/apn_device.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/apn_device.rs` ```rust use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; /// The kind of Apple Push Notification Device /// /// NOTE: we deviate from the normal snake casing here to align with Apple's naming conventions. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[cfg_attr(feature = "sqlx", sqlx(type_name = "apn_device_kind"))] #[expect(non_camel_case_types)] pub enum ApnDeviceKind { iPhone, iPad, MacBook, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[cfg_attr( feature = "sqlx", sqlx(type_name = "apn_device_env", rename_all = "snake_case") )] pub enum ApnDeviceEnv { Development, Production, } /// A Apple Push Notification Device struct /// /// This struct represents an Apple Push Notification Device (APN Device). #[derive(Debug, Serialize, Deserialize, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct ApnDevice { /// The id of the apn_device pub id: String, pub name: String, /// The ID of the target this apn_device is associated with pub kind: ApnDeviceKind, // Biggest thing to think about is that we'd want to have this in json in our task manager as we'd like to be able to filter there. // We don't nessecarily need to store all of this in the database though? 
pub env: ApnDeviceEnv, pub created_at: DateTime, } impl PartialEq for ApnDevice { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Hash for ApnDevice { fn hash(&self, state: &mut H) { self.id.hash(state); } } impl PartialEq<&str> for ApnDevice { fn eq(&self, other: &&str) -> bool { &self.id == other } } impl ApnDevice { /// Create a new ApnDevice pub fn new( id: String, name: String, kind: ApnDeviceKind, env: ApnDeviceEnv, created_at: DateTime, ) -> Self { ApnDevice { id, name, kind, env, created_at, } } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for ApnDevice // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::ApnDevice; use sqlx::Row; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for ApnDevice { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { Ok(ApnDevice::new( row.try_get("id")?, row.try_get("name")?, row.try_get("kind")?, row.try_get("env")?, row.try_get("created_at")?, )) } } } ``` ### `shared/mrpf_core/src/recon/domain.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/domain.rs` ```rust use super::Ipv4Address; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::hash::{Hash, Hasher}; /// A Domain struct /// /// The definition of a domain is a single FQDN assoicated with a target and includes the /// current state of the domain in DNS. #[derive(Debug, Serialize, Deserialize, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct Domain { /// Fully Qualified Domain Name (FQDN) pub fqdn: String, /// The ID of the target this domain is associated with pub target_id: String, // Biggest thing to think about is that we'd want to have this in json in our task manager as we'd like to be able to filter there. // We don't nessecarily need to store all of this in the database though? #[serde(default)] pub is_wildcard: Option, /// List of technologies detected on the domain (e.g., web server, CMS, frameworks). #[serde(default)] pub technologies: Vec, /// Current active CNAME chain #[serde(default)] pub cnames: Vec, /// Current active IP addresses #[serde(default)] pub ip_addresses: HashSet, } impl PartialEq for Domain { fn eq(&self, other: &Self) -> bool { self.fqdn == other.fqdn } } impl Hash for Domain { fn hash(&self, state: &mut H) { self.fqdn.hash(state); } } impl PartialEq<&str> for Domain { fn eq(&self, other: &&str) -> bool { &self.fqdn == other } } impl Domain { /// Create a new Domain pub fn new( fqdn: String, target_id: String, is_wildcard: Option, cnames: Vec, ip_addresses: HashSet, ) -> Self { Domain { fqdn: fqdn.to_lowercase(), target_id, // is_active: None, // is_dangling: None, is_wildcard, technologies: vec![], cnames: cnames .into_iter() .map(|c| c.trim().trim_end_matches('.').to_lowercase()) .collect(), ip_addresses, } } /// Determine if the domain is active based on the presence of IP addresses. pub fn is_active(&self) -> bool { !self.ip_addresses.is_empty() } /// Determine if the domain is dangling based on the absence of IP addresses and presence of CNAMEs. pub fn is_dangling(&self) -> bool { self.ip_addresses.is_empty() && !self.cnames.is_empty() } /// Merge two domains /// /// This method merges the attributes of another `Domain` instance into the current instance. /// The merge strategy is as follows: /// /// * `target_id` - Replaced by the `target_id` of the other domain. 
/// * `is_active` - Replaced if the other domain's `is_active` is `Some`. /// * `is_dangling` - Replaced if the other domain's `is_dangling` is `Some`. /// * `is_wildcard` - Replaced if the other domain's `is_wildcard` is `Some`. /// * `technologies` - Extended with the other domain's `technologies`. /// * `cnames` - Extended with the other domain's `cnames`. /// * `ip_addresses` - Extended with the other domain's `ip_addresses`. /// /// # Arguments /// /// * `other` - The other `Domain` instance to merge into the current instance. pub fn merge(&mut self, other: Domain) { // Replace target id with the incoming one self.target_id = other.target_id; // Prefer the other's explicit wildcard flag when provided if other.is_wildcard.is_some() { self.is_wildcard = other.is_wildcard; } // Extend lists; dedup can be done at a higher level if needed if !other.technologies.is_empty() { self.technologies.extend(other.technologies); } if !other.cnames.is_empty() { self.cnames.extend(other.cnames); } // Merge ip addresses into the existing HashSet (union) self.ip_addresses.extend(other.ip_addresses); } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for Domain // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::Domain; use super::Ipv4Address; use sqlx::Row; use sqlx::types::Json; use std::collections::HashSet; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Domain { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { Ok(Domain::new( row.try_get("id").or_else(|_| row.try_get("fqdn"))?, row.try_get("target_id")?, row.try_get("is_wildcard").ok(), // technologies: row // .try_get("technologies") // .map(|v: Vec| v) // .unwrap_or_else(|_| vec![]), row.try_get("cnames") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| vec![]), row.try_get::>, _>("ip_addresses") .map(|Json(v)| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()), )) } } } ``` ### `shared/mrpf_core/src/recon/ipv4address.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/ipv4address.rs` ```rust use crate::{Error, Result}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fmt; use std::hash::{Hash, Hasher}; use std::net::Ipv4Addr; /// A Ipv4Address struct /// /// This struct represents an IPv4 address with associated metadata. 
#[derive(Debug, Serialize, Deserialize, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct Ipv4Address { /// The IPv4 address pub ip: Ipv4Addr, /// The target ID associated with this IP address pub target_id: String, /// TCP ports associated with this IP address #[serde(default)] pub tcp_ports: HashSet, /// UDP ports associated with this IP address #[serde(default)] pub udp_ports: HashSet, } impl PartialEq for Ipv4Address { fn eq(&self, other: &Self) -> bool { self.ip == other.ip } } impl Hash for Ipv4Address { fn hash(&self, state: &mut H) { self.ip.hash(state); } } impl PartialEq<&str> for Ipv4Address { fn eq(&self, other: &&str) -> bool { &self.ip.to_string() == other } } impl fmt::Display for Ipv4Address { /// Format the Ipv4Address including ports /// /// Example output: /// Ipv4Address(10.0.0.1, tcp_ports=[80, 443], udp_ports=[53]) fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Start with type and IP write!(f, "Ipv4Address({}", self.ip)?; // Collect optional parts let mut parts: Vec = Vec::with_capacity(2); if !self.tcp_ports.is_empty() { let mut tcp: Vec = self.tcp_ports.iter().copied().collect(); tcp.sort_unstable(); let tcp_str = tcp .iter() .map(|p| p.to_string()) .collect::>() .join(", "); parts.push(format!("tcp_ports=[{}]", tcp_str)); } if !self.udp_ports.is_empty() { let mut udp: Vec = self.udp_ports.iter().copied().collect(); udp.sort_unstable(); let udp_str = udp .iter() .map(|p| p.to_string()) .collect::>() .join(", "); parts.push(format!("udp_ports=[{}]", udp_str)); } if !parts.is_empty() { write!(f, ", {}", parts.join(", "))?; } write!(f, ")") } } impl Ipv4Address { /// Create a new Ipv4Address pub fn new( ip: Ipv4Addr, target_id: String, tcp_ports: HashSet, udp_ports: HashSet, ) -> Self { Ipv4Address { ip, target_id, tcp_ports, udp_ports, } } } impl Ipv4Address { /// Merge another Ipv4Address into this one. pub fn merge_from(&mut self, mut other: Ipv4Address) -> Result { if !self.target_id.is_empty() && !other.target_id.is_empty() && self.target_id != other.target_id { return Err(Error::ModelMergeError(format!( "Conflicting target_ids for {}: '{}' vs '{}'", self.ip, self.target_id, other.target_id ))); } if self.target_id.is_empty() && !other.target_id.is_empty() { self.target_id = std::mem::take(&mut other.target_id); } // Merge ports: ensure self.tcp_ports/self.udp_ports is the larger set to minimize rehashing if other.tcp_ports.len() > self.tcp_ports.len() { std::mem::swap(&mut self.tcp_ports, &mut other.tcp_ports); } self.tcp_ports.extend(other.tcp_ports); if other.udp_ports.len() > self.udp_ports.len() { std::mem::swap(&mut self.udp_ports, &mut other.udp_ports); } self.udp_ports.extend(other.udp_ports); Ok(()) } /// Merge an incoming address into a map keyed by IP. pub fn merge_into_map( map: &mut HashMap, incoming: Ipv4Address, ) -> Result { if let Some(existing) = map.get_mut(&incoming.ip) { existing.merge_from(incoming)?; } else { map.insert(incoming.ip, incoming); } Ok(()) } /// Merge a collection of IPv4Addresses into a deduplicated set keyed by IP. pub fn merge_set>(iter: I) -> Result> { let mut map: HashMap = HashMap::new(); for addr in iter { Self::merge_into_map(&mut map, addr)?; } Ok(map.into_values().collect()) } } // Implemented only when the "sqlx" feature is enabled. #[cfg(feature = "sqlx")] mod sqlx_impl { use super::Ipv4Address; use sqlx::Row; use std::collections::HashSet; use std::net::Ipv4Addr; // We implement FromRow for postgres rows (PgRow) and generic rows where possible. 
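    // Expected row shape (illustrative; column types are inferred from the reads below, not from a
    // schema dump):
    //   ip          TEXT   -- parsed into std::net::Ipv4Addr
    //   target_id   TEXT
    //   tcp_ports   INT[]  -- narrowed to u16 per element
    //   udp_ports   INT[]  -- narrowed to u16 per element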
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Ipv4Address { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { // Read ip column as text then parse into IpAddr so we can detect v4 vs v6. let ip_str: String = row.try_get("ip")?; let ipv4_addr: Ipv4Addr = ip_str.parse().map_err(move |e| { sqlx::Error::Decode(Box::new(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("failed to parse ip '{}': {}", ip_str, e), ))) })?; let target_id: String = row.try_get("target_id")?; let tcp_ports: Vec = row.try_get("tcp_ports")?; let tcp_ports_u16: HashSet = tcp_ports.into_iter().map(|p| p as u16).collect(); let udp_ports: Vec = row.try_get("udp_ports")?; let udp_ports_u16: HashSet = udp_ports.into_iter().map(|p| p as u16).collect(); Ok(Ipv4Address::new( ipv4_addr, target_id, tcp_ports_u16, udp_ports_u16, )) } } } ``` ### `shared/mrpf_core/src/recon/job.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/job.rs` ```rust use super::JobHistory; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::fmt; use std::hash::{Hash, Hasher}; use std::str::FromStr; use tracing::error; use uuid::Uuid; #[cfg(feature = "task_manager")] use croner::Cron; /// The status of a job in our Recon database. /// /// This is based on the TaskState enum used in the task manager service, /// but adapted for whats strictly needed in a job context. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[cfg_attr( feature = "sqlx", sqlx(type_name = "job_state", rename_all = "snake_case") )] pub enum JobState { /// The task has never been run NeverRun, /// The task is not scheduled yet Pending, /// The task is currently running Running, /// The task has partially failed, meaning that some of the sub tasks have failed. /// This state can only be applied to TaskContainers PartiallyFailed, /// The task has succeeded Succeeded, /// The task has failed Failed, /// The task is disabled Disabled, /// The task has timed out Timeout, } #[derive(Serialize, Deserialize, Debug, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct Job { /// Unique identifier for the job. pub id: Uuid, /// Unique name of the job pub name: String, /// Optional target id pub target_id: Option, /// Job definition in JSON format pub definition: serde_json::Value, /// Crontab schedule pub schedule: String, /// Recurrence interval pub recurrence: Option, /// Latest run #[serde(default)] pub latest_run: Option, /// Timestamp when the collection was created (UTC) pub created_at: DateTime, } impl PartialEq for Job { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Hash for Job { fn hash(&self, state: &mut H) { self.id.hash(state); } } impl PartialEq for Job { fn eq(&self, other: &Uuid) -> bool { self.id == *other } } impl Job { /// Create a new TaskCollection #[expect(clippy::too_many_arguments)] pub fn new( id: Uuid, name: String, target_id: Option, definition: serde_json::Value, schedule: String, recurrence: Option, latest_run: Option, created_at: DateTime, ) -> Self { Job { id, name, target_id, definition, schedule, recurrence, latest_run, created_at, } } /// Calculate the state of the job based on its latest run. /// /// If no runs have occurred, returns NotRun. pub fn state(&self) -> JobState { if let Some(history) = &self.latest_run { return history.state; } JobState::NeverRun } /// Returns true if the job is due to run based on its schedule /// and recurrence counter. 
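// Worked example (illustrative, not in the original source): with the 6-field
// schedule "0 * * * * *" (second 0 of every minute), created_at = 10:00:00Z,
// no prior runs and now = 10:05:30Z, the first fire time strictly after
// created_at is 10:01:00Z, which is <= now, so the call returns
// (true, Some(10:01:00Z)). A recurrence of Some(0) always yields (false, None),
// and an unparseable cron expression is logged and treated as not due.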
#[cfg(feature = "task_manager")] pub fn is_due(&self, now: DateTime) -> (bool, Option>) { // Recurrence == 0 explicitly disables further runs. if let Some(0) = self.recurrence { return (false, None); } match Cron::from_str(&self.schedule) { Ok(cron_schedule) => { // Determine reference start let last_run_time = if let Some(history) = &self.latest_run { history.started_at } else { // Use created_at if the task has never run before self.created_at }; // Iterate forward for the next scheduled time strictly after last_run_time. let mut iter = cron_schedule.iter_after(last_run_time); if let Some(next_time) = iter.next() { // If next scheduled time is <= now, we're due. if next_time <= now { return (true, Some(next_time)); } else { return (false, Some(next_time)); } } } Err(_) => { error!(schedule = %self.schedule, "Invalid cron expression: treating as not due "); } } (false, None) } /// Generate a transformed copy of `self.definition` where symbolic UUID /// placeholders are replaced with freshly generated UUIDv7 values. /// /// Placeholder formats (must match the *entire* JSON string value): /// /// 1. `"UUID"` – always replaced by a newly generated **unique** UUIDv7. /// Each occurrence is independent. /// 2. `"UUID:SomeName"` – all occurrences sharing the same `SomeName` /// are replaced by the **same** UUIDv7 generated at the *first* /// encounter of that name during the traversal. /// /// The traversal walks the entire JSON tree (objects, arrays, nested /// structures). Non‑matching strings are left untouched. Only *exact* /// matches to the patterns above trigger replacement – substrings are /// ignored (e.g. `"foo-UUID"` is not changed). /// /// Additionally, the returned JSON object (top level only if it is an /// object) is augmented with a `task_collection_id` field containing this /// `TaskCollection`'s own `id` (unchanged). /// /// ### Example /// Input definition (placeholders shown): /// ```json /// { /// "collection": { /// "id": "UUID:Root", /// "children": [ /// { "task": { "id": "UUID", "depends_on": "UUID:Root" } }, /// { "task": { "id": "UUID", "parent": "UUID:Root" } }, /// { "container": { "id": "UUID:SharedContainer" } }, /// { "ref": "UUID:SharedContainer" } /// ] /// } /// } /// ``` /// /// Possible output (UUIDs will differ at runtime): /// ```json /// { /// "collection": { /// "id": "4c71c9c5-5f40-7b2e-9d9b-9d9a4d5c6f11", // UUID:Root /// "children": [ /// { "task": { "id": "6d5fd1aa-5f40-7c96-8a2c-c5e9d1f0b001", "depends_on": "4c71c9c5-5f40-7b2e-9d9b-9d9a4d5c6f11" } }, /// { "task": { "id": "a43e4c02-5f40-7f21-92d4-0f843f6d4002", "parent": "4c71c9c5-5f40-7b2e-9d9b-9d9a4d5c6f11" } }, /// { "container": { "id": "3b9f6f80-5f40-7e65-b3d1-6a52f98c5003" } }, // UUID:SharedContainer first occurrence /// { "ref": "3b9f6f80-5f40-7e65-b3d1-6a52f98c5003" } // UUID:SharedContainer reused /// ] /// }, /// "task_collection_id": "" /// } /// ``` /// Note how: /// * Each bare `UUID` became a distinct value. /// * All `UUID:Root` references share the same value. /// * All `UUID:SharedContainer` references share another value. /// * The extra `task_collection_id` field was injected at the top level. 
pub fn generate_definition_with_uuids(&self) -> serde_json::Value { use serde_json::Value; use std::collections::HashMap; fn walk(value: &mut Value, named: &mut HashMap) { match value { Value::String(s) => { if s == "UUID" { let new_id = Uuid::now_v7(); *s = new_id.to_string(); } else if let Some(rest) = s.strip_prefix("UUID:") && !rest.is_empty() { let entry = named.entry(rest.to_string()).or_insert_with(Uuid::now_v7); *s = entry.to_string(); } } Value::Array(arr) => { for v in arr.iter_mut() { walk(v, named); } } Value::Object(map) => { for (_k, v) in map.iter_mut() { walk(v, named); } } _ => {} } } let mut def_with_ids = self.definition.clone(); let mut named_cache: HashMap = HashMap::new(); // Pre-seed the named placeholder cache so all occurrences of // "UUID:Root" resolve to the actual job id. This lets the task // manager treat the root container id as the job id, enabling // finalize_task_collections to update recon.job_history rows using // the root container ids. named_cache.insert("Root".to_string(), self.id); walk(&mut def_with_ids, &mut named_cache); // Inject the task_collection_id into the definition JSON (top-level object only) if let Value::Object(ref mut map) = def_with_ids { map.insert( "task_collection_id".to_string(), Value::String(self.id.to_string()), ); } def_with_ids } } impl fmt::Display for Job { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "TaskCollection {{ {} }}", self.id)?; Ok(()) } } #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { use super::*; use serde_json::json; #[test] fn test_generate_definition_with_uuids_basic() { let definition = json!({ "root": { "id": "UUID:Root", "items": [ {"id": "UUID"}, {"id": "UUID"}, {"container": {"id": "UUID:Shared"}}, {"ref": "UUID:Shared"}, "UUID:Root" ] } }); let tc = Job::new( Uuid::now_v7(), "test".to_string(), None, definition, "0 * * * * *".to_string(), None, None, Utc::now(), ); let transformed = tc.generate_definition_with_uuids(); // Extract some fields to assert relationships let root_id = transformed["root"]["id"].as_str().unwrap(); let reused_root = transformed["root"]["items"][4].as_str().unwrap(); assert_eq!(root_id, reused_root, "Named root UUID should be reused"); let shared1 = transformed["root"]["items"][2]["container"]["id"] .as_str() .unwrap(); let shared2 = transformed["root"]["items"][3]["ref"].as_str().unwrap(); assert_eq!(shared1, shared2, "Named Shared UUID should be reused"); let first_uuid = transformed["root"]["items"][0]["id"].as_str().unwrap(); let second_uuid = transformed["root"]["items"][1]["id"].as_str().unwrap(); assert_ne!( first_uuid, second_uuid, "Bare UUID placeholders must be unique" ); // The injected field must exist assert!(transformed.get("task_collection_id").is_some()); // Ensure all replaced values parse as UUIDs for candidate in [ root_id, reused_root, shared1, shared2, first_uuid, second_uuid, ] { assert!(Uuid::parse_str(candidate).is_ok()); } } } // Manual FromRow implementation to convert SQL array/vector columns into HashSet fields // Implemented only when the "sqlx" feature is enabled. #[cfg(feature = "sqlx")] mod sqlx_impl { use super::{Job, JobHistory, JobState}; use chrono::{DateTime, Utc}; use sqlx::Row; use uuid::Uuid; // We implement FromRow for postgres rows (PgRow) and generic rows where possible. 
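// Illustrative sketch (added for this document, not present in the repo): one query
// shape that satisfies the FromRow impl below. The schema/table names `recon.job`
// and `recon.job_history` are assumptions; the `latest_history_*` aliases match the
// optional columns read in from_row, and a plain `SELECT * FROM recon.job` would
// also decode, just with `latest_run = None`.
#[allow(dead_code)]
async fn fetch_jobs_with_latest_run(pool: &sqlx::PgPool) -> Result<Vec<Job>, sqlx::Error> {
    sqlx::query_as::<_, Job>(
        r#"
        SELECT j.*,
               h.id           AS latest_history_id,
               h.job_id       AS latest_history_job_id,
               h.state        AS latest_history_state,
               h.stats        AS latest_history_stats,
               h.started_at   AS latest_history_started_at,
               h.completed_at AS latest_history_completed_at
        FROM recon.job j
        LEFT JOIN LATERAL (
            SELECT * FROM recon.job_history jh
            WHERE jh.job_id = j.id
            ORDER BY jh.started_at DESC
            LIMIT 1
        ) h ON TRUE
        "#,
    )
    .fetch_all(pool)
    .await
}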
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Job { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { let id: Uuid = row.try_get("id")?; let name: String = row.try_get("name")?; let target_id: Option = row.try_get("target_id")?; let definition: serde_json::Value = row.try_get("definition")?; let schedule: String = row.try_get("schedule")?; let recurrence: Option = row.try_get("recurrence")?; // If the latest_history_id column exists we performed a query on the job_history table // to get the latest run information joined into this row. let latest_run = match row.try_get::, _>("latest_history_id") { Ok(h_id) => { // The column latest_history_id exists, but still might be NULL if it never run. if let Some(h_id) = h_id { let h_job_id: Uuid = row.try_get("latest_history_job_id")?; let h_state: JobState = row.try_get("latest_history_state")?; let h_stats: serde_json::Value = row.try_get("latest_history_stats")?; let h_started_at: DateTime = row.try_get("latest_history_started_at")?; let h_completed_at: Option> = row.try_get("latest_history_completed_at")?; Some(JobHistory::new( h_id, h_job_id, h_state, h_stats, h_started_at, h_completed_at, )) } else { None } } Err(sqlx::Error::ColumnNotFound(_)) => None, // Column absent: no latest run. Err(e) => return Err(e), }; let created_at: DateTime = row.try_get("created_at")?; Ok(Job::new( id, name, target_id, definition, schedule, recurrence, latest_run, created_at, )) } } } ``` ### `shared/mrpf_core/src/recon/job_history.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/job_history.rs` ```rust use super::JobState; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::fmt; use std::hash::{Hash, Hasher}; use uuid::Uuid; #[derive(Serialize, Deserialize, Debug, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct JobHistory { /// Unique identifier for the job. pub id: Uuid, /// reference to the job id pub job_id: Uuid, /// Status of the job run pub state: JobState, /// Statistics gathered during the run pub stats: serde_json::Value, /// Started_at timestamp pub started_at: DateTime, /// Finished_at timestamp #[serde(default)] pub completed_at: Option>, } impl PartialEq for JobHistory { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Hash for JobHistory { fn hash(&self, state: &mut H) { self.id.hash(state); } } impl PartialEq for JobHistory { fn eq(&self, other: &Uuid) -> bool { self.id == *other } } impl JobHistory { /// Create a new TaskCollectionHistory pub fn new( id: Uuid, job_id: Uuid, state: JobState, stats: serde_json::Value, started_at: DateTime, completed_at: Option>, ) -> Self { JobHistory { id, job_id, state, stats, started_at, completed_at, } } } impl fmt::Display for JobHistory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "TaskCollectionHistory {{ {} }}", self.id)?; Ok(()) } } // Manual FromRow implementation to convert SQL array/vector columns into HashSet fields // Implemented only when the "sqlx" feature is enabled. #[cfg(feature = "sqlx")] mod sqlx_impl { use super::JobHistory; use super::JobState; use chrono::{DateTime, Utc}; use sqlx::Row; use uuid::Uuid; // We implement FromRow for postgres rows (PgRow) and generic rows where possible. 
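// Illustrative sketch (added for this document, not present in the repo): the
// snake_case serde layout a JobHistory row round-trips through, e.g. when carried
// on Job::latest_run. Placed inside this feature-gated module purely for illustration.
#[cfg(test)]
mod serde_shape_example {
    use super::super::JobHistory;

    #[test]
    fn job_history_deserializes_from_snake_case_json() {
        let value = serde_json::json!({
            "id": "018f63d2-0000-7000-8000-000000000001",
            "job_id": "018f63d2-0000-7000-8000-000000000002",
            "state": "succeeded",
            "stats": {},
            "started_at": "2025-01-01T00:00:00Z"
            // `completed_at` is #[serde(default)], so it can be omitted entirely.
        });
        let history: JobHistory = serde_json::from_value(value).expect("valid JobHistory JSON");
        assert!(history.completed_at.is_none());
    }
}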
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for JobHistory { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { let id: Uuid = row.try_get("id")?; let job_id: Uuid = row.try_get("job_id")?; let state: JobState = row.try_get("state")?; let definition: serde_json::Value = row.try_get("stats")?; let started_at: DateTime = row.try_get("started_at")?; let completed_at: Option> = row.try_get("completed_at")?; Ok(JobHistory::new( id, job_id, state, definition, started_at, completed_at, )) } } } ``` ### `shared/mrpf_core/src/recon/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/mod.rs` ```rust mod apn_device; mod domain; mod ipv4address; mod job; mod job_history; mod notification; mod soa_record; mod target; mod template_key; mod word; mod word_category; mod wordlist; pub use apn_device::{ApnDevice, ApnDeviceEnv, ApnDeviceKind}; pub use domain::Domain; pub use ipv4address::Ipv4Address; pub use job::{Job, JobState}; pub use job_history::JobHistory; pub use notification::Notification; pub use notification::NotificationCategory; pub use soa_record::SoaRecord; pub use target::Target; pub use template_key::TemplateKey; pub use word::Word; pub use word_category::WordCategory; pub use wordlist::Wordlist; ``` ### `shared/mrpf_core/src/recon/notification.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/notification.rs` ```rust use super::Ipv4Address; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::hash::{Hash, Hasher}; use uuid::Uuid; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "sqlx", derive(sqlx::Type))] #[cfg_attr( feature = "sqlx", sqlx(type_name = "notification_category", rename_all = "snake_case") )] pub enum NotificationCategory { Generic, TaskCollectionResult, Alert, } impl std::fmt::Display for NotificationCategory { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { NotificationCategory::Generic => write!(f, "generic"), NotificationCategory::TaskCollectionResult => write!(f, "task_collection_result"), NotificationCategory::Alert => write!(f, "alert"), } } } impl std::fmt::Display for Notification { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(sub) = &self.subtitle { write!( f, "{} [{}] {}: {} - {}", self.id, self.category, self.title, sub, self.content ) } else { write!( f, "{} [{}] {} - {}", self.id, self.category, self.title, self.content ) } } } /// A Notification /// /// This struct represents a Notification sent to one or more devices. 
#[derive(Debug, Serialize, Deserialize, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct Notification { /// The id of the notification pub id: Uuid, /// The title of the notification pub title: String, /// The subtitle of the notification #[serde(default)] pub subtitle: Option, /// The content of the notification pub content: String, /// Whether the notification has been read #[serde(default)] pub is_read: bool, /// The category of the notification pub category: NotificationCategory, /// Target ID #[serde(default)] pub target_id: Option, /// Created at timestamp pub created_at: DateTime, } impl PartialEq for Notification { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Hash for Notification { fn hash(&self, state: &mut H) { self.id.hash(state); } } impl PartialEq<&Uuid> for Notification { fn eq(&self, other: &&Uuid) -> bool { &self.id == *other } } impl Notification { /// Create a new Notification #[expect(clippy::too_many_arguments)] pub fn new( id: Uuid, title: String, subtitle: Option, content: String, is_read: bool, category: NotificationCategory, target_id: Option, created_at: DateTime, ) -> Self { Notification { id, title, subtitle, content, is_read, category, target_id, created_at, } } /// Create a new generic Notification pub fn new_generic( title: String, subtitle: Option, content: String, target_id: Option, ) -> Self { Notification { id: Uuid::now_v7(), title, subtitle, content, is_read: false, category: NotificationCategory::TaskCollectionResult, target_id, created_at: Utc::now(), } } /// Create a new error Notification pub fn new_error(error_message: String, target_id: Option) -> Self { Notification { id: Uuid::now_v7(), title: "Error Notification".to_string(), subtitle: None, content: error_message, is_read: false, category: NotificationCategory::Alert, target_id, created_at: Utc::now(), } } /// Create a new Task Collection Result Notification /// /// # Arguments /// * `new_fqdns` - An optional vector of new FQDNs found during the task. If None, the FQDNs were not scanned, if emtpy vector, no new FQDNs were found. /// * `new_ip_ports` - An optional vector of new IPv4 addresses with ports found during the task. If None, the IP ports were not scanned, if empty vector, no new IP ports were found. /// * `target_id` - An optional target ID associated with the notification. 
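// Illustrative outcomes (derived from the branches below):
//   new_fqdns = None            -> "## FQDNs:\nNot scanned"
//   new_fqdns = Some(vec![])    -> "## FQDNs:\nNo new FQDNs found"
//   new_fqdns = Some(non-empty) -> a sorted, comma-separated FQDN list
// The same three-way pattern applies to `new_ipv4addresses` ("No new ports found"
// for an empty vector), sorted by IP and rendered via Ipv4Address's Display impl.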
pub fn new_task_collection_result( new_fqdns: Option>, new_ipv4addresses: Option>, target_id: Option, ) -> Self { let title = if let Some(target_id) = &target_id { format!("{} - Results", target_id) } else { "Results".to_string() }; let fqdn_section = if let Some(fqdns) = &new_fqdns { if fqdns.is_empty() { "## FQDNs:\nNo new FQDNs found".to_string() } else { format!("## FQDNs: \n{}", { let mut fqdn_list: Vec<&str> = fqdns.iter().map(|s| s.as_str()).collect(); fqdn_list.sort_unstable(); fqdn_list.join(", ") }) } } else { "## FQDNs:\nNot scanned".to_string() }; let ipv4_address_section = if let Some(ipv4_addresses) = &new_ipv4addresses { if ipv4_addresses.is_empty() { "## IP Addresses:\nNo new ports found".to_string() } else { let ip_addresses: Vec = { let mut vec: Vec<&Ipv4Address> = ipv4_addresses.iter().collect(); vec.sort_unstable_by(|a, b| a.ip.cmp(&b.ip)); vec.into_iter().map(|ip| ip.to_string()).collect() }; format!("## IP Addresses: \n{}", ip_addresses.join(", ")) } } else { "## IP Addresses:\nNot scanned".to_string() }; let content = format!("{}\n{}", fqdn_section, ipv4_address_section); Notification { id: Uuid::now_v7(), title, subtitle: None, content, is_read: false, category: NotificationCategory::TaskCollectionResult, target_id, created_at: Utc::now(), } } /// Convert the Notification to a JSON string supported by Apple Push Notification service. /// /// https://developer.apple.com/documentation/usernotifications/generating-a-remote-notification pub fn to_apn_json(&self) -> serde_json::Value { let mut aps = serde_json::Map::new(); let alert = serde_json::json!({ "title": self.title, "subtitle": self.subtitle, "body": self.content, }); aps.insert("alert".to_string(), alert); aps.insert( "sound".to_string(), serde_json::Value::String("default".to_string()), ); aps.insert( self.category.to_string(), serde_json::Value::String("notification".to_string()), ); let mut payload = serde_json::Map::new(); payload.insert("aps".to_string(), serde_json::Value::Object(aps)); // payload.insert("badge".to_string(), serde_json::Value::Number(1.into())); // payload.insert("thread-id".to_string(), serde_json::Value::String(self.id.clone())); serde_json::Value::Object(payload) } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for Notification // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::Notification; use sqlx::Row; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Notification { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { Ok(Notification::new( row.try_get("id")?, row.try_get("title")?, row.try_get("subtitle")?, row.try_get("content")?, row.try_get("is_read")?, row.try_get("category")?, row.try_get("target_id")?, row.try_get("created_at")?, )) } } } ``` ### `shared/mrpf_core/src/recon/soa_record.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/soa_record.rs` ```rust use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; /// A SoaRecord struct /// /// Represents a DNS Start of Authority (SOA) record. #[derive(Debug, Serialize, Deserialize, Eq)] #[serde(rename_all = "snake_case")] pub struct SoaRecord { /// The Owner Name of the SOA record, always the zone apex (e.g., example.com). pub owner_name: String, /// The Primary Nameserver for the zone. pub primary_ns: String, /// The hostmaster email for the zone, with the "@" replaced by a dot. 
pub hostmaster: String, } impl PartialEq for SoaRecord { fn eq(&self, other: &Self) -> bool { self.owner_name == other.owner_name } } impl Hash for SoaRecord { fn hash(&self, state: &mut H) { self.owner_name.hash(state); } } impl PartialEq<&str> for SoaRecord { fn eq(&self, other: &&str) -> bool { &self.owner_name == other } } impl SoaRecord { /// Create a new SoaRecord pub fn new(owner_name: String, primary_ns: String, hostmaster: String) -> Self { SoaRecord { owner_name: owner_name.trim().to_lowercase(), primary_ns: primary_ns.trim().to_lowercase(), hostmaster: hostmaster.trim().to_lowercase(), } } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for SoaRecord // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::SoaRecord; use sqlx::Row; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for SoaRecord { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { Ok(SoaRecord::new( row.try_get("owner_name")?, row.try_get("primary_ns")?, row.try_get("hostmaster")?, )) } } } ``` ### `shared/mrpf_core/src/recon/target.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/target.rs` ```rust use ipnet::Ipv4Net; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::fmt; use std::hash::{Hash, Hasher}; #[derive(Serialize, Deserialize, Debug, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct Target { /// Unique identifier for the target, typically a domain name or organization name. pub id: String, /// Top-Level Domains (TLDs) associated with the target. #[serde(default)] pub tlds: HashSet, /// Internal TLDs associated with the target. #[serde(default)] pub internal_tlds: HashSet, /// Certificate organization names associated with the target. #[serde(default)] pub cert_org_names: HashSet, /// Type to be decided, we might want to use my own Ipv4AddressRanges instead? Thats probably what we'll be using in most cases? #[serde(default)] pub ip_ranges: HashSet, /// Regex patterns for in-scope FQDNs. #[serde(default)] pub fqdns_in_scope: HashSet, /// Regex patterns for out-of-scope FQDNs. #[serde(default)] pub fqdns_out_of_scope: HashSet, /// Regex patterns for in-scope URLs. #[serde(default)] pub urls_in_scope: HashSet, /// Regex patterns for out-of-scope URLs. #[serde(default)] pub urls_out_of_scope: HashSet, /// Indicates if the target has any bounties. #[serde(default)] pub has_bounties: Option, /// Indicates if the target is private. #[serde(default)] pub is_private: Option, /// Indicates if the target is marked as a favorite. #[serde(default)] pub favorite: bool, } impl PartialEq for Target { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Hash for Target { fn hash(&self, state: &mut H) { self.id.hash(state); } } impl PartialEq<&str> for Target { fn eq(&self, other: &&str) -> bool { &self.id == other } } impl Target { /// Create a new Target /// /// Sanitizes TLDs by lowercasing and trimming whitespace and trailing dots. 
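// e.g. " Example.COM. " is stored as "example.com" in both `tlds` and `internal_tlds`.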
#[expect(clippy::too_many_arguments)] pub fn new( id: String, tlds: HashSet, internal_tlds: HashSet, cert_org_names: HashSet, ip_ranges: HashSet, fqdns_in_scope: HashSet, fqdns_out_of_scope: HashSet, urls_in_scope: HashSet, urls_out_of_scope: HashSet, has_bounties: Option, is_private: Option, favorite: bool, ) -> Self { Target { id, tlds: tlds .into_iter() .map(|c| c.trim().trim_end_matches('.').to_lowercase()) .collect(), internal_tlds: internal_tlds .into_iter() .map(|c| c.trim().trim_end_matches('.').to_lowercase()) .collect(), cert_org_names, ip_ranges, fqdns_in_scope, fqdns_out_of_scope, urls_in_scope, urls_out_of_scope, has_bounties, is_private, favorite, } } /// Merges another `Target` into the current one, combining their attributes. /// /// This method extends the current target's lists with the lists from the other target. /// If the other target has non-None values for `has_bounties`, `is_private`, or `favorite`, /// these values will overwrite the current target's corresponding values. /// /// # Arguments /// /// * `other` - Another `Target` instance to merge into the current one. pub fn merge(&mut self, other: Target) { self.tlds.extend(other.tlds); self.internal_tlds.extend(other.internal_tlds); self.cert_org_names.extend(other.cert_org_names); self.ip_ranges.extend(other.ip_ranges); self.fqdns_in_scope.extend(other.fqdns_in_scope); self.fqdns_out_of_scope.extend(other.fqdns_out_of_scope); self.urls_in_scope.extend(other.urls_in_scope); self.urls_out_of_scope.extend(other.urls_out_of_scope); if other.has_bounties.is_some() { self.has_bounties = other.has_bounties; } if other.is_private.is_some() { self.is_private = other.is_private; } self.favorite = other.favorite; } } impl fmt::Display for Target { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Target {{ {} }}", self.id)?; Ok(()) } } // Manual FromRow implementation to convert SQL array/vector columns into HashSet fields // Implemented only when the "sqlx" feature is enabled. #[cfg(feature = "sqlx")] mod sqlx_impl { use std::collections::HashSet; use super::Target; use ipnet::Ipv4Net; use sqlx::Row; // We implement FromRow for postgres rows (PgRow) and generic rows where possible. impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Target { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { let id: String = row.try_get("id")?; let tlds: HashSet = row .try_get("tlds") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); let internal_tlds: HashSet = row .try_get("internal_tlds") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); let cert_org_names: HashSet = row .try_get("cert_org_names") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); let fqdns_in_scope: HashSet = row .try_get("fqdns_in_scope") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); let fqdns_out_of_scope: HashSet = row .try_get("fqdns_out_of_scope") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); let urls_in_scope: HashSet = row .try_get("urls_in_scope") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); let urls_out_of_scope: HashSet = row .try_get("urls_out_of_scope") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); // Somehow it's not converting from ipnet with the feature enabled. I want // to actually use my ipaddressrange anyway so leaving this out for now. 
// let ip_ranges_vec: Vec = row.try_get("ip_ranges")?; // let ip_ranges: std::collections::HashSet = ip_ranges_vec.into_iter().collect(); let ip_ranges: HashSet = HashSet::new(); // Optional scalar fields let has_bounties: Option = row.try_get("has_bounties").ok(); let is_private: Option = row.try_get("is_private").ok(); let favorite: bool = row.try_get("favorite").unwrap_or_default(); Ok(Target::new( id, tlds, internal_tlds, cert_org_names, ip_ranges, fqdns_in_scope, fqdns_out_of_scope, urls_in_scope, urls_out_of_scope, has_bounties, is_private, favorite, )) } } } ``` ### `shared/mrpf_core/src/recon/template_key.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/template_key.rs` ```rust use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; use uuid::Uuid; #[derive(Serialize, Deserialize, Debug, Eq)] #[serde(rename_all = "snake_case")] pub struct TemplateKey { id: Uuid, name: String, description: String, } impl PartialEq for TemplateKey { fn eq(&self, other: &Self) -> bool { self.name == other.name } } impl Hash for TemplateKey { fn hash(&self, state: &mut H) { self.name.hash(state); } } impl PartialEq<&str> for TemplateKey { fn eq(&self, other: &&str) -> bool { &self.name == other } } impl TemplateKey { /// Create a new TemplateKey pub fn new(id: Option, name: String, description: String) -> Self { TemplateKey { id: id.unwrap_or_else(Uuid::now_v7), name, description, } } pub fn id(&self) -> Uuid { self.id } pub fn name(&self) -> &str { &self.name } pub fn description(&self) -> &str { &self.description } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for Word // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::TemplateKey; use sqlx::Row; use uuid::Uuid; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for TemplateKey { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { let category_id: Uuid = row.try_get("id")?; Ok(TemplateKey::new( Some(category_id), row.try_get("name")?, row.try_get("description")?, )) } } } ``` ### `shared/mrpf_core/src/recon/word.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/word.rs` ```rust use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::hash::{Hash, Hasher}; use uuid::Uuid; use super::WordCategory; #[derive(Serialize, Deserialize, Debug, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct Word { /// Unique identifier for the word id: Uuid, /// Unique name for the word. Can contain template variables. 
name: String, /// Optional target ID that this word is associated with #[serde(default)] target_id: Option, /// Categories that this word belongs to #[serde(default)] categories: HashSet, /// Keys that this word can be used for in templates #[serde(default)] applicable_template_keys: HashSet, } impl PartialEq for Word { fn eq(&self, other: &Self) -> bool { if self.name == other.name && self.target_id == other.target_id { return true; } false } } impl Hash for Word { fn hash(&self, state: &mut H) { self.name.hash(state); self.target_id.hash(state); } } impl Word { /// Create a new Word pub fn new( id: Option, name: String, target_id: Option, categories: HashSet, applicable_template_keys: HashSet, ) -> Self { let id = id.unwrap_or_else(Uuid::now_v7); Word { id, name, target_id, categories, applicable_template_keys, } } pub fn extract_embedded_template_keys(&self) -> HashSet { todo!( "Word.name should become a Template and this should return the keys used in the templates" ) } pub fn id(&self) -> Uuid { self.id } pub fn name(&self) -> &str { &self.name } pub fn target_id(&self) -> &Option { &self.target_id } pub fn categories(&self) -> &HashSet { &self.categories } pub fn applicable_template_keys(&self) -> &HashSet { &self.applicable_template_keys } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for Word // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::{Word, WordCategory}; use sqlx::Row; use std::collections::HashSet; use uuid::Uuid; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for Word { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { let word_id: Uuid = row.try_get("id")?; let category_ids: Vec = row.try_get("category_ids")?; let category_names: Vec = row.try_get("category_names")?; let category_descriptions: Vec = row.try_get("category_descriptions")?; let categories: HashSet = category_ids .into_iter() .zip(category_names) .zip(category_descriptions) .map(|((id, name), description)| WordCategory::new(Some(id), name, description)) .collect(); // TODO: We need to define TemplateKey struct and use it here. 
For now lets just do the names let applicable_template_keys: HashSet = row .try_get("template_key_names") .map(|v: Vec| v.into_iter().collect()) .unwrap_or_else(|_| HashSet::new()); // let template_key_ids: Vec = row.try_get("template_key_ids")?; // let template_key_names: Vec = row.try_get("template_key_names")?; // let template_key_descriptions: Vec = // row.try_get("template_key_descriptions")?; // let template_keys: HashSet = template_key_ids // .into_iter() // .zip(template_key_names) // .zip(template_key_descriptions) // .map(|((id, name), description)| TemplateKey::new(id, name, description)) // .collect(); Ok(Word::new( Some(word_id), row.try_get("name")?, row.try_get("target_id")?, categories, applicable_template_keys, )) } } } ``` ### `shared/mrpf_core/src/recon/word_category.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/word_category.rs` ```rust use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; use uuid::Uuid; #[derive(Serialize, Deserialize, Debug, Eq, Clone)] #[serde(rename_all = "snake_case")] pub struct WordCategory { id: Uuid, name: String, description: String, } impl PartialEq for WordCategory { fn eq(&self, other: &Self) -> bool { self.name == other.name } } impl Hash for WordCategory { fn hash(&self, state: &mut H) { self.name.hash(state); } } impl PartialEq<&str> for WordCategory { fn eq(&self, other: &&str) -> bool { &self.name == other } } impl WordCategory { /// Create a new WordCategory pub fn new(id: Option, name: String, description: String) -> Self { WordCategory { id: id.unwrap_or_else(Uuid::now_v7), name, description, } } pub fn id(&self) -> Uuid { self.id } pub fn name(&self) -> &str { &self.name } pub fn description(&self) -> &str { &self.description } } // --------------------------------------------------------------------------- // sqlx FromRow implementation (feature-gated) for Word // --------------------------------------------------------------------------- #[cfg(feature = "sqlx")] mod sqlx_impl { use super::WordCategory; use sqlx::Row; use uuid::Uuid; impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for WordCategory { fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { let category_id: Uuid = row.try_get("id")?; Ok(WordCategory::new( Some(category_id), row.try_get("name")?, row.try_get("description")?, )) } } } ``` ### `shared/mrpf_core/src/recon/wordlist.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/recon/wordlist.rs` ```rust use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; #[derive(Serialize, Deserialize, Debug, Eq)] #[serde(rename_all = "snake_case")] pub struct Wordlist { id: String, templates: Vec, #[serde(default)] comments: Option, #[serde(default)] target_id: Option, } impl PartialEq for Wordlist { fn eq(&self, other: &Self) -> bool { self.id == other.id } } impl Hash for Wordlist { fn hash(&self, state: &mut H) { self.id.hash(state); } } impl PartialEq<&str> for Wordlist { fn eq(&self, other: &&str) -> bool { &self.id == other } } impl Wordlist { /// Create a new Wordlist pub fn new( id: String, templates: Vec, comments: Option, target_id: Option, ) -> Self { Wordlist { id, templates, comments, target_id, } } /// Merges another `Wordlist` into the current one, combining their attributes. /// /// This method merges the attributes of another `Wordlist` instance into the current instance. /// Templates from other will be merged into the current `templates`. 
/// If the other target has non-None values for `comments` or `target_id`, /// these values will overwrite the current target's corresponding values. /// /// # Arguments /// /// * `other` - Another `Wordlist` instance to merge into the current one. pub fn merge(&mut self, other: Wordlist) { self.templates.extend(other.templates); if other.comments.is_some() { self.comments = other.comments; } if other.target_id.is_some() { self.target_id = other.target_id; } } } ``` ### `shared/mrpf_core/src/tasks/definitions/conditional.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/conditional.rs` ```rust use crate::Result; use crate::tasks::models::{TaskDefinition, TaskDefinitionId, TaskId, WorkerRequirements}; use serde::{Deserialize, Serialize}; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct Conditional { id: TaskDefinitionId, // TODO: add parameters below } impl TaskDefinition for Conditional { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[]; const TIMEOUT_SECS: u32 = 30; // Placeholder short timeout } impl Conditional { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } // pub fn new() -> Self { // Self { // id: TaskDefinitionId::now_v7(), // } // } /// Recreate a Conditional from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task. pub fn from_definition(id: TaskId, mut _definition: serde_json::Value) -> Result { Ok(Self { id }) } #[cfg(feature = "task_execution")] async fn _execute(&mut self, _context: &ExecutionContext) -> Result { todo!("Implement conditional task execution") } } ``` ### `shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/dns_scan.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/dns_scan.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::templates::DataKind; use mrpf_engine::network::dns::DnsQueryType; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::{debug, info}; #[cfg(feature = "task_execution")] use { crate::tasks::ExecutionContext, mrpf_dns_resolver::DnsResolver, mrpf_engine::ScannerConfig, mrpf_engine::utils::get_interface, }; /// A single DnsScan task. #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, Clone)] #[serde(tag = "dns_scan", rename_all = "snake_case")] pub struct DnsScan { /// The id of the task template. /// /// When loaded from a runtime task, it will use it's TaskId. When created as /// a reusable template, it will use TaskDefinitionId. id: TaskDefinitionId, /// The domain names to resolve domain_names: HashSet, /// The DNS query type to use for the scans. query_type: DnsQueryType, /// The target ID associated with this scan. target_id: String, /// Output key output: String, } impl TaskDefinition for DnsScan { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::RawSocketAccess, ]; const TIMEOUT_SECS: u32 = 900; // DNS scans moderate duration /// Execute the DNS Scan task. /// /// Returns a set of FQDNs and Certificate Organizations found during the scan. 
#[cfg(feature = "task_execution")] async fn execute(&mut self, context: &ExecutionContext) -> Result { let task_id = self.id().to_string(); info!( task_id = task_id, domain_names = ?self.domain_names, query_type = %self.query_type, "Executing DNS Scan" ); let interface = get_interface(None); let src_mac = interface.mac.expect("No MAC address found"); // Find the first available Ipv4 Address and use the first ip of the subnet as router_ip. let (router_ip, interface_ip) = interface .ips .iter() .find(|network| network.is_ipv4()) .map(|network| match network.ip() { std::net::IpAddr::V4(ipv4) => match network.network().to_canonical() { std::net::IpAddr::V4(network_id) => { let router_id = std::net::Ipv4Addr::from(network_id.to_bits() + 1); (router_id, ipv4) } _ => unreachable!(), }, _ => unreachable!(), }) .expect("No IPv4 address found"); debug!( task_id = task_id, // interface = %interface, src_mac = %src_mac, src_ip = %interface_ip, router_ip = %router_ip, "Building Scanner Config" ); let config = ScannerConfig::builder() .interface(interface) .src_mac(src_mac) .src_ip(interface_ip) .router_ip(router_ip) .build(); let scanner = DnsResolver::new(config); debug!(task_id = task_id, "Starting DNS Scanner"); let domain_names = std::mem::take(&mut self.domain_names); let dns_records = scanner.run(domain_names, self.query_type); if dns_records.is_empty() { info!(task_id = task_id, "No DNS records found during scan"); } else { info!( task_id = task_id, found_addresses = dns_records.len(), "Found DNS records during DNS Scan" ); debug!( task_id = task_id, "Discovered DNS records: {:?}", dns_records ); debug!( task_id = task_id, output_key = self.output, "Storing scan results" ); let data = DataKind::MapOfDnsRecords(dns_records); context .db()? .put_task_data(self.id(), &self.output, data) .await?; } info!(task_id = task_id, "DNS Scan completed successfully"); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl DnsScan { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Initialize new DnsResolver task pub fn new( domain_names: HashSet, query_type: DnsQueryType, target_id: String, output: String, ) -> Self { Self { id: TaskDefinitionId::now_v7(), domain_names, query_type, target_id, output, } } /// Recreate a DnsResolver from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task. 
pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let domain_names: HashSet = serde_json::from_value(definition["domain_names"].take())?; let query_type: DnsQueryType = serde_json::from_value(definition["query_type"].take())?; let target_id: String = serde_json::from_value(definition["target_id"].take())?; let output: String = serde_json::from_value(definition["output"].take())?; Ok(Self { id, domain_names, query_type, target_id, output, }) } } ``` ### `shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/dns_scan_aggregator.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/dns_scan_aggregator.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use serde::{Deserialize, Serialize}; use tracing::info; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct DnsScanAggregator { /// The id for the task id: TaskId, // The container id that contains the task to aggregate results from container_id_to_aggregate: TaskContainerId, /// The output variable of the task collection to store FQDN results in output_key: String, } impl DnsScanAggregator { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } pub fn new(container_id_to_aggregate: TaskContainerId, output_key: String) -> Self { Self { id: TaskDefinitionId::now_v7(), container_id_to_aggregate, output_key, } } /// Recreate a DnsScanAggregator from a definition on a runtime task. /// /// # Arguments /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let container_id_to_aggregate: TaskContainerId = serde_json::from_value(definition["container_id_to_aggregate"].take())?; let output_key: String = serde_json::from_value(definition["output_key"].take())?; Ok(Self { id, container_id_to_aggregate, output_key, }) } /// Aggregate the results from the tasks in the `container_id_to_aggregate` /// and store them in the `root_container_id` under the keys `fqdns` and `organizations`. /// /// # Arguments /// * `context` - The worker context containing the database client. #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), container_id_to_aggregate = %self.container_id_to_aggregate, "Executing DnsScanAggregator task"); let dns_record_count = context .db()? .aggregate_map_of_dns_records( &self.container_id_to_aggregate, &self.output_key, &self.output_key, ) .await?; info!( task_id = %self.id(), container_id_to_aggregate = %self.container_id_to_aggregate, dns_record_count = %dns_record_count, "Aggregated DNS Records results for DNS Scan tasks." 
); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for DnsScanAggregator { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[WorkerRequirements::TaskDataAccess]; const TIMEOUT_SECS: u32 = 300; // Aggregators moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/dns_scan_generator.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/dns_scan_generator.rs` ```rust use super::{DnsScan, DnsScanAggregator}; use crate::tasks::models::{ CompletionPolicy, Task, TaskContainerId, TaskDefinition, TaskId, TaskKind, TaskManagerQueueMessage, TaskResult, WorkerRequirements, }; use crate::{Error, Result, Templates}; use mrpf_engine::network::dns::DnsQueryType; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::debug; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Generates distributedDnsScanResolver tasks #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, Clone)] #[serde(rename_all = "snake_case")] pub struct DnsScanGenerator { /// The ID of the task id: TaskId, /// The root_container_id of the task collection to retrieve data from root_container_id: TaskContainerId, /// The Tcp Syn task destination container ID destination_container_id: TaskContainerId, /// The Tcp Syn Aggregator task will be placed in this container aggregator_container_id: TaskContainerId, /// The domain names to scan domain_names: Templates, /// The DNS query type to use for the scans query_type: DnsQueryType, /// Target ID associated with this scan target_id: String, /// Maximum fqdns to scan per generatedDnsScanResolver task #[serde(default = "default_max_per_task")] max_per_task: usize, /// The output variable to store results in #[serde(default = "default_output_key")] output: String, // #[serde(default = "default_cert_org_output_key")] // Had this incorrectly. This is present on TLS scraper, not TCP SYN // cert_org_output_key: String, } fn default_max_per_task() -> usize { 100 } fn default_output_key() -> String { "dns_results".to_string() } impl TaskDefinition for DnsScanGenerator { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::QueueAccess, WorkerRequirements::TaskDataAccess, ]; const TIMEOUT_SECS: u32 = 120; // Generators expected to finish quickly /// Generate DnsScan tasks and send them to the queue. /// /// A typical DnsScan task structure will look as follows, however, we're able to /// use different structures as needed. /// /// - Sequential container /// - Parallel container /// -DnsScan 1 task /// -DnsScan 2 task /// - ... /// - DnsScanAggregator task #[cfg(feature = "task_execution")] async fn execute(&mut self, context: &ExecutionContext) -> Result { let mut expected_task_count = 0; let mut domain_names: HashSet = HashSet::new(); let (domain_name_refs, domain_name_refs_with_path) = self.domain_names.extract_references(); // Fetch all references from the database let ip_data_map = context .db()? .get_data_map( &self.root_container_id, domain_name_refs, domain_name_refs_with_path, ) .await?; debug!(data = ?ip_data_map, "Fetched data for domain_names"); // Parse all templates to get the actual domain names for template in &self.domain_names { debug!(template = %template, "domain_names template"); domain_names.insert(template.try_from(&ip_data_map)?); } debug!(ranges = ?domain_names, "Parsed domain names from templates"); // Generate parallel tasks // TODO: We need to decide how to split domain names. 
For now just a single task let dns = DnsScan::new( domain_names, self.query_type, self.target_id.clone(), self.output.clone(), ); let task = Task::new(CompletionPolicy::FailOnFailure, TaskKind::DnsScan(dns)); debug!(task_id = %task.id(), parent_id = %self.destination_container_id, generator_id = %self.id, "Sending 'Dns task' creation message to queue"); let message = TaskManagerQueueMessage::CreateTask { task, parent_id: Some(self.destination_container_id), generator_id: self.id, }; context.queue_client()?.send_mgr_message(&message).await?; expected_task_count += 1; // CreateDnsScanAggregator task let aggregator = DnsScanAggregator::new(self.destination_container_id, self.output.clone()); let task = Task::new( CompletionPolicy::FailOnFailure, TaskKind::DnsScanAggregator(aggregator), ); debug!(task_id = %task.id(), parent_id = %self.aggregator_container_id, generator_id = %self.id, "Sending 'DnsScanAggregator task' creation message to queue"); let message = TaskManagerQueueMessage::CreateTask { task, parent_id: Some(self.aggregator_container_id), generator_id: self.id, }; context.queue_client()?.send_mgr_message(&message).await?; expected_task_count += 1; debug!(expected_task_count = %expected_task_count, "DnsScanGenerator done generating tasks"); Ok(TaskResult::GeneratedTasks(expected_task_count)) } } impl DnsScanGenerator { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Recreate a DnsScanGenerator from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let root_container_id: TaskContainerId = serde_json::from_value(definition["root_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'root_container_id' from definition: {}", e )) })?; let destination_container_id: TaskContainerId = serde_json::from_value(definition["destination_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'destination_container_id' from definition: {}", e )) })?; let aggregator_container_id: TaskContainerId = serde_json::from_value(definition["aggregator_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'aggregator_container_id' from definition: {}", e )) })?; let domain_names: Templates = serde_json::from_value(definition["domain_names"].take()) .map_err(|e| { Error::Other(format!( "Failed to parse 'domain_names' from definition: {}", e )) })?; let query_type: DnsQueryType = serde_json::from_value(definition["query_type"].take()) .map_err(|e| { Error::Other(format!( "Failed to parse 'query_type' from definition: {}", e )) })?; let target_id: String = serde_json::from_value(definition["target_id"].take())?; let max_per_task: usize = serde_json::from_value(definition["max_per_task"].take()) .unwrap_or(default_max_per_task()); let output: String = serde_json::from_value(definition["output"].take()).unwrap_or(default_output_key()); Ok(Self { id, root_container_id, destination_container_id, aggregator_container_id, domain_names, query_type, target_id, max_per_task, output, }) } } ``` ### `shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/dns_scan_tasks/mod.rs` ```rust pub mod dns_scan; pub mod dns_scan_aggregator; pub mod dns_scan_generator; pub use dns_scan::DnsScan; pub use dns_scan_aggregator::DnsScanAggregator; pub use dns_scan_generator::DnsScanGenerator; ``` ### 
`shared/mrpf_core/src/tasks/definitions/filter.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/filter.rs` ```rust use crate::Result; use crate::tasks::models::{TaskDefinition, TaskDefinitionId, TaskId, WorkerRequirements}; use mrpf_matchers::{Matcher, WordMatcher}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::marker::PhantomData; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct Filter, T: ?Sized> { /// The id of the task template id: TaskDefinitionId, /// The data variable to extract from the job input_var: String, /// The variable to store the output data output_var: String, /// The task id containing the data to transform data_task_id: TaskId, /// The matcher to use for filtering matcher: M, /// This will ensure the compiler knows to use the correct type /// that is implemented for the matcher #[serde(skip)] _marker: PhantomData, } impl + Send + Sync, T: DeserializeOwned + Send + Sync> TaskDefinition for Filter { // fn id(&self) -> &TaskId { // &self.id // } const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[WorkerRequirements::TaskDataAccess]; const TIMEOUT_SECS: u32 = 60; // Generic filters expected to be fast } impl Filter where M: Matcher + for<'de> Deserialize<'de>, T: for<'de> Deserialize<'de>, { /// Create a new Filter task /// /// # Arguments /// /// * `input_var` - The data variable to extract from the job /// * `output_var` - The variable to store the output data. If not provided, it will default to the input_var /// which will replace the data in the database /// * `matcher` - The matcher to use for filtering pub fn new( task_id: TaskId, data_task_id: TaskId, input_var: String, output_var: Option, matcher: M, ) -> Self { // If no output_var is provided, use the input_var let output_var = output_var.unwrap_or_else(|| input_var.clone()); Self { id: task_id, data_task_id, input_var, output_var, matcher, _marker: PhantomData, } } /// Recreate a Filter from a database row /// /// The id here is the TaskId of the runtime task. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let input_var = serde_json::from_value(definition["input_var"].take())?; let output_var = serde_json::from_value(definition["output_var"].take())?; let data_task_id = serde_json::from_value(definition["data_task_id"].take())?; let matcher = serde_json::from_value(definition["matcher"].take())?; Ok(Self { id, data_task_id, input_var, output_var, matcher, _marker: PhantomData, }) } #[cfg(feature = "task_execution")] async fn _execute(&mut self, _context: &ExecutionContext) -> Result { todo!("Fix the get task data typing"); // // Fetch the input data from the job // let input_data: Vec = db // .get_task_data(&self.data_task_id, &self.input_var) // .await? // .into(); // // Filter the data // let output_data: Vec = input_data // .into_iter() // .filter(|item| self.matcher.matches(item)) // .collect(); // // Store the output data in the job // db.replace_data(&self.output_var, output_data).await; // Ok(()) } } // Explicit implementation for WordMatcher and str // This is required because str is ?Sized impl TaskDefinition for Filter { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[WorkerRequirements::TaskDataAccess]; const TIMEOUT_SECS: u32 = 60; // Simple filter should be fast } impl Filter { /// Return the runtime task's unique identifier. 
pub fn id(&self) -> &TaskId { &self.id } /// Create a new Filter task pub fn new( task_id: TaskId, data_task_id: TaskId, input_var: String, output_var: Option, matcher: WordMatcher, ) -> Self { // If no output_var is provided, use the input_var let output_var = output_var.unwrap_or_else(|| input_var.clone()); Self { id: task_id, data_task_id, input_var, output_var, matcher, _marker: PhantomData, } } /// Recreate a Filter from a database row /// /// The id here is the TaskId of the runtime task. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let input_var = serde_json::from_value(definition["input_var"].take())?; let output_var = serde_json::from_value(definition["output_var"].take())?; let data_task_id = serde_json::from_value(definition["data_task_id"].take())?; let matcher = serde_json::from_value(definition["matcher"].take())?; Ok(Self { id, data_task_id, input_var, output_var, matcher, _marker: PhantomData, }) } /// Execute the filtering #[cfg(feature = "task_execution")] async fn _execute(&mut self, _context: &ExecutionContext) -> Result { // // Fetch the input data from the job // let input_data: Vec = db // .get_task_data(&self.data_task_id, &self.input_var) // .await?; // // Filter the data // let output_data: Vec = input_data // .into_iter() // .filter(|item| self.matcher.matches(item.as_str())) // .collect(); // // Store the output data in the job // db.replace_task_data(&self.data_task_id, &self.output_var, output_data.into()) // .await?; Ok(()) } } ``` ### `shared/mrpf_core/src/tasks/definitions/insert_data.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/insert_data.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::templates::DataKind; use serde::{Deserialize, Serialize}; use tracing::info; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Task to insert data in a task collections data table #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct InsertData { /// The id for the task id: TaskId, // The root container id of the task collection to insert data into root_container_id: TaskContainerId, /// The output variable to store the results in key: String, /// The data to insert data: Option, } impl InsertData { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } pub fn new(root_container_id: TaskContainerId, key: String, data: Option) -> Self { Self { id: TaskDefinitionId::now_v7(), root_container_id, key, data, } } /// Recreate a InsertData from a definition on a runtime task. /// /// # Arguments /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let root_container_id: TaskContainerId = serde_json::from_value(definition["root_container_id"].take())?; let key: String = serde_json::from_value(definition["key"].take())?; let data: Option = serde_json::from_value(definition["data"].take())?; Ok(Self { id, root_container_id, key, data, }) } /// Execute the InsertData task using the provided database client. /// /// # Arguments /// * `context` - The worker context providing database access. 
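// Call shape (illustrative): a worker that has deserialized this task would run
//     let result = insert_data.execute(&context).await?;
// which stores `data` under `key` for the collection's root container when `data`
// is Some, and resolves to TaskResult::Completed(TaskState::Succeeded) either way.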
#[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { if let Some(data) = self.data.take() { let data_kind_name = data.to_string(); info!(task_id = %self.id(), root_container_id = %self.root_container_id, key = %self.key, data_kind = %data_kind_name, "Executing InsertData task"); context .db()? .put_collection_data(&self.root_container_id, &self.key, data) .await?; } else { info!(task_id = %self.id(), root_container_id = %self.root_container_id, key = %self.key, "Executing InsertData task with no data to insert"); } Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for InsertData { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[WorkerRequirements::TaskDataAccess]; const TIMEOUT_SECS: u32 = 60; // Simple insert should be fast } ``` ### `shared/mrpf_core/src/tasks/definitions/notifications/error.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/notifications/error.rs` ```rust use crate::Result; use crate::recon::Notification; use crate::tasks::models::{ TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; use serde::{Deserialize, Serialize}; use tracing::info; /// Send a notification #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct ErrorNotification { /// The id for the task id: TaskId, /// The error message error: String, /// Optional target id target_id: Option, } impl ErrorNotification { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Create a new ErrorNotification task. pub fn new(error: String, target_id: Option) -> Self { Self { id: TaskDefinitionId::now_v7(), error, target_id, } } /// Recreate a ErrorNotification from a definition on a runtime task. /// /// # Arguments /// /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { Ok(Self { id, error: serde_json::from_value(definition["error"].take())?, target_id: serde_json::from_value(definition["target_id"].take())?, }) } /// Execute the ErrorNotification task using the provided database client. /// /// # Arguments /// * `context` - The worker context providing database access. 
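// Minimal usage sketch; target_id is optional and its concrete type is not shown
// in this flattened view, so it is omitted here:
//
//     let task = ErrorNotification::new(
//         "TLS scan failed: no usable IPv4 interface".to_string(),
//         None,
//     );
//
// execute() turns this into Notification::new_error(error, target_id) and stores it
// through the recon database client.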
#[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), error = %self.error, "Executing ErrorNotification task"); let notification = Notification::new_error(self.error.clone(), self.target_id.clone()); info!(task_id = %self.id(), notification = %notification, "Created error notification"); context.db()?.put_notification(¬ification).await?; Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for ErrorNotification { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::ReconDataAccess, ]; const TIMEOUT_SECS: u32 = 60; // Recon API upserts moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/notifications/generic.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/notifications/generic.rs` ```rust use crate::recon::Notification; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::{Result, Template}; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; use serde::{Deserialize, Serialize}; use tracing::{debug, info}; /// Send a notification #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct GenericNotification { /// The id for the task id: TaskId, /// The title of the notification title: String, /// Subtitle of the notification subtitle: Option, /// The root container id root_container_id: TaskContainerId, /// Content of the notification content: Template, /// Optional target id target_id: Option, } impl GenericNotification { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Create a new GenericNotification task. pub fn new( title: String, subtitle: Option, root_container_id: TaskContainerId, content: Template, target_id: Option, ) -> Self { Self { id: TaskDefinitionId::now_v7(), title, subtitle, root_container_id, content, target_id, } } /// Recreate a GenericNotification from a definition on a runtime task. /// /// # Arguments /// /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { Ok(Self { id, title: serde_json::from_value(definition["title"].take())?, subtitle: serde_json::from_value(definition["subtitle"].take())?, root_container_id: serde_json::from_value(definition["root_container_id"].take())?, content: serde_json::from_value(definition["content"].take())?, target_id: serde_json::from_value(definition["target_id"].take())?, }) } /// Execute the GenericNotification task using the provided database client. /// /// # Arguments /// * `context` - The worker context providing database access. #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), title = %self.title, "Executing GenericNotification task"); let (data_keys, data_keys_with_path) = self.content.extract_references(); debug!(task_id = %self.id(), ?data_keys, "Extracted data keys from template"); let data_map = context .db()? 
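// The content template may reference collection data, e.g. a template along the lines
// of "Found {new_fqdns} for {target}" (reference syntax assumed for illustration).
// extract_references() above yields those keys, the data map fetched below resolves
// them, and render_to_string(&data_map, Some(", ")) joins set-valued entries with ", ".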
.get_data_map(&self.root_container_id, data_keys, data_keys_with_path) .await?; let content = self.content.render_to_string(&data_map, Some(", "))?; let notification = Notification::new_generic( self.title.clone(), self.subtitle.clone(), content, self.target_id.clone(), ); info!(task_id = %self.id(), notification = %notification, "Created generic notification"); context.db()?.put_notification(¬ification).await?; Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for GenericNotification { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::ReconDataAccess, ]; const TIMEOUT_SECS: u32 = 60; // Recon API upserts moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/notifications/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/notifications/mod.rs` ```rust mod error; mod generic; mod task_result; pub use error::ErrorNotification; pub use generic::GenericNotification; pub use task_result::TaskCollectionResultNotification; ``` ### `shared/mrpf_core/src/tasks/definitions/notifications/task_result.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/notifications/task_result.rs` ```rust use crate::Result; use crate::recon::Notification; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::templates::DataKind; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::{info, warn}; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Send a Task Collection Result notification #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct TaskCollectionResultNotification { /// The id for the task id: TaskId, /// Root container id root_container_id: TaskContainerId, /// Optional target id target_id: Option, /// The task data key for any new FQDNs found new_fqdns_key: Option, /// The task data key for any new IP ports found new_ipv4_addresses: Option, } impl TaskCollectionResultNotification { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Create a new TaskResultNotification task. pub fn new( root_container_id: TaskContainerId, target_id: Option, new_fqdns_key: Option, new_ipv4_addresses: Option, ) -> Self { Self { root_container_id, id: TaskDefinitionId::now_v7(), target_id, new_fqdns_key, new_ipv4_addresses, } } /// Recreate a TaskCollectionResultNotification from a definition on a runtime task. /// /// # Arguments /// /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { Ok(Self { id, root_container_id: serde_json::from_value(definition["root_container_id"].take())?, target_id: serde_json::from_value(definition["target_id"].take())?, new_fqdns_key: serde_json::from_value(definition["new_fqdns_key"].take())?, new_ipv4_addresses: serde_json::from_value(definition["new_ipv4_addresses"].take())?, }) } /// Execute the TaskCollectionResultNotification task using the provided database client. /// /// # Arguments /// * `context` - The worker context providing database access. 
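// Construction sketch with hypothetical data keys; both keys are optional and the
// notification is only sent when at least one of them resolves to data:
//
//     let task = TaskCollectionResultNotification::new(
//         root_container_id,
//         Some(target_id),
//         Some("new_fqdns".to_string()),          // SetOfFqdns or SetOfDomains
//         Some("new_ipv4_addresses".to_string()), // SetOfIpv4Addresses
//     );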
#[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), "Executing TaskCollectionResultNotification task"); // Collect any new FQDNs let mut has_fqdns = false; let new_fqdns = if let Some(key) = &self.new_fqdns_key { match context.db()?.get_data(&self.root_container_id, key).await? { DataKind::SetOfFqdns(fqdns) => { has_fqdns = true; Some(fqdns) } DataKind::SetOfDomains(domains) => { let fqdns: HashSet = domains.into_iter().map(|domain| domain.fqdn).collect(); has_fqdns = true; Some(fqdns) } other => { warn!(task_id = %self.id(), key = %key, data_kind = %other.to_string(), "Unexpected data kind for new FQDNs key"); None } } } else { None }; // Collect any new IPv4 addresses let mut has_ipv4_addresses = false; let new_ipv4_addresses = if let Some(key) = &self.new_ipv4_addresses { match context.db()?.get_data(&self.root_container_id, key).await? { DataKind::SetOfIpv4Addresses(ipv4addresses) => { has_ipv4_addresses = true; Some(ipv4addresses) } other => { warn!(task_id = %self.id(), key = %key, data_kind = %other.to_string(), "Unexpected data kind for new IP ports key"); None } } } else { None }; // Send notification if any new data is found if !has_fqdns && !has_ipv4_addresses { info!(task_id = %self.id(), "No new FQDNs or IP addresses found to include in notification"); } else { let notification = Notification::new_task_collection_result( new_fqdns, new_ipv4_addresses, self.target_id.clone(), ); context.db()?.put_notification(¬ification).await?; info!(task_id = %self.id(), notification = %notification, "Created task collection result notification"); } Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for TaskCollectionResultNotification { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::ReconDataAccess, ]; const TIMEOUT_SECS: u32 = 60; // Recon API upserts moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/recon_api_tasks/list_models.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/recon_api_tasks/list_models.rs` ```rust use super::model_type::ModelType; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::templates::DataKind; use crate::{Error, Result}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tracing::{debug, info}; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Get details for a given Recon API model id and store it in the data table of the task collection #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct ListModels { /// The id for the task id: TaskId, // The root container id of the task collection to insert data into root_container_id: TaskContainerId, /// The output variable to store the results in key: String, /// The model type to fetch model_type: ModelType, /// Filters to apply when fetching the model(s) /// /// For single model fetches (like Target or Domain), this should include an "id" field. /// For multi-model fetches (like Domains), this can include filters specific to the model type. filters: serde_json::Map, } impl ListModels { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Create a new ListModels task. 
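// Illustrative filters for the Domains branch of execute() below; the keys
// ("target_id", "is_active", "is_wildcard", "is_dangling") are the ones read via
// get_filter_as / get_optional_filter_as, the values are made up:
//
//     let mut filters = serde_json::Map::new();
//     filters.insert("target_id".into(), serde_json::json!("acme-corp"));
//     filters.insert("is_active".into(), serde_json::json!(true));
//     let task = ListModels::new(root_container_id, "domains".into(), ModelType::Domains, filters);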
pub fn new( root_container_id: TaskContainerId, key: String, model_type: ModelType, filters: serde_json::Map, ) -> Self { Self { id: TaskDefinitionId::now_v7(), root_container_id, key, model_type, filters, } } /// Recreate a ListModels from a definition on a runtime task. /// /// # Arguments /// /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let root_container_id: TaskContainerId = serde_json::from_value(definition["root_container_id"].take())?; let key: String = serde_json::from_value(definition["key"].take())?; let model_type: ModelType = serde_json::from_value(definition["model_type"].take())?; let filters: serde_json::Map = serde_json::from_value(definition["filters"].take())?; Ok(Self { id, root_container_id, key, model_type, filters, }) } /// Execute the ListModels task using the provided database client. /// /// # Arguments /// * `context` - The worker context providing database access. #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), root_container_id = %self.root_container_id, key = %self.key, "Executing ListModels task"); let data = match &self.model_type { ModelType::Target => { let target_id: String = self.get_filter_as("target_id")?; debug!(target_id = %target_id, "Fetching target from recon db"); let target = DataKind::Target(context.db()?.get_target(&target_id).await?); debug!(target_id = %target_id, "Fetched target from recon db"); target } ModelType::Ipv4Address => { let ip_address: String = self.get_filter_as("ip_address")?; debug!(ip_address = %ip_address, "Fetching IPv4 address from recon db"); let ipv4_address = DataKind::Ipv4Address(context.db()?.get_ip_address(&ip_address).await?); debug!(ip_address = ?ipv4_address, "Fetched IPv4 address from recon db"); ipv4_address } ModelType::Ipv4Addresses => { let target_id: Option = self.get_optional_filter_as("target_id")?; debug!(target_id = ?target_id, "Fetching IPv4 addresses from recon db"); let ipv4_addresses = DataKind::SetOfIpv4Addresses( context.db()?.list_ip_addresses(target_id.clone()).await?, ); if ipv4_addresses.is_empty() { return Err(Error::Other("No IPv4 addresses found".into())); } debug!(target_id = ?target_id, "Fetched IPv4 addresses from recon db"); ipv4_addresses } ModelType::Domain => { let fqdn: String = self.get_filter_as("fqdn")?; debug!(fqdn = %fqdn, "Fetching domain from recon db"); let domain = DataKind::Domain(context.db()?.get_domain(&fqdn).await?); debug!(fqdn = %fqdn, "Fetched domain from recon db"); domain } ModelType::Domains => { let target_id: String = self.get_filter_as("target_id")?; let is_active: Option = self.get_optional_filter_as("is_active")?; let is_wildcard: Option = self.get_optional_filter_as("is_wildcard")?; let is_dangling: Option = self.get_optional_filter_as("is_dangling")?; debug!( target_id = %target_id, is_active = ?is_active, is_wildcard = ?is_wildcard, is_dangling = ?is_dangling, "Fetching domains from recon db" ); let domains = DataKind::SetOfDomains( context .db()? 
.list_domains(target_id.clone(), is_active, is_wildcard, is_dangling) .await?, ); if domains.is_empty() { return Err(Error::Other("No domains found".into())); } debug!(target_id = %target_id, "Fetched domains from recon db"); domains } ModelType::Word => { let name: String = self.get_filter_as("name")?; let target_id: Option = self.get_optional_filter_as("target_id")?; debug!( word_name = %name, target_id = ?target_id, "Fetching word from recon db" ); let word = DataKind::Word(context.db()?.get_word_by_name(&name, target_id).await?); debug!(word_name = %name, "Fetched word from recon db"); word } ModelType::Words => { let target_id: Option = self.get_optional_filter_as("target_id")?; let included_category_names: Option> = self.get_optional_filter_as("included_category_names")?; let excluded_category_names: Option> = self.get_optional_filter_as("excluded_category_names")?; let included_category_names = included_category_names.unwrap_or_default(); let excluded_category_names = excluded_category_names.unwrap_or_default(); debug!( target_id = ?target_id, included_category_names = ?included_category_names, excluded_category_names = ?excluded_category_names, "Fetching words from recon db" ); let words = DataKind::SetOfWords( context .db()? .generate_wordlist( included_category_names, excluded_category_names, target_id, ) .await?, ); if words.is_empty() { return Err(Error::Other("No words found".into())); } debug!(word_count = ?words.len(), "Fetched words from recon db"); words } other => { return Err(crate::Error::Other(format!( "Unsupported model type '{}' in ListModels task", other ))); } }; context .db()? .put_collection_data(&self.root_container_id, &self.key, data) .await?; debug!(key = %self.key, "Stored target in task collection data"); Ok(TaskResult::Completed(TaskState::Succeeded)) } /// Extract a required filter and deserialize it to T. /// /// # Examples /// /// ``` /// let filter: String = self.get_filter_as("filter_key")?; /// ``` fn get_filter_as(&self, key: &str) -> Result { let v = self.filters.get(key).ok_or_else(|| { crate::Error::Other(format!("Missing '{}' filter for ListModels task", key)) })?; serde_json::from_value(v.clone()).map_err(|e| { crate::Error::Other(format!( "Invalid '{}' filter type for ListModels: {}", key, e )) }) } /// Extract an optional filter and deserialize it to Option fn get_optional_filter_as(&self, key: &str) -> Result> { match self.filters.get(key) { Some(v) => serde_json::from_value(v.clone()).map(Some).map_err(|e| { crate::Error::Other(format!( "Invalid '{}' filter type for ListModels: {}", key, e )) }), None => Ok(None), } } } impl TaskDefinition for ListModels { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::ReconDataAccess, ]; const TIMEOUT_SECS: u32 = 120; // Recon API queries moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/recon_api_tasks/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/recon_api_tasks/mod.rs` ```rust pub mod list_models; pub mod upsert_models; pub use list_models::ListModels; pub use upsert_models::UpsertModels; pub mod model_type; ``` ### `shared/mrpf_core/src/tasks/definitions/recon_api_tasks/model_type.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/recon_api_tasks/model_type.rs` ```rust use serde::{Deserialize, Serialize}; use std::fmt; /// The types of models that can be fetched from the Recon API. 
/// /// These are used in our more generic model mutation tasks like ListModels and UpsertModels. #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ModelType { Domain, Domains, Target, Targets, Word, Words, Ipv4Address, Ipv4Addresses, } impl fmt::Display for ModelType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { ModelType::Domain => "domain", ModelType::Domains => "domains", ModelType::Target => "target", ModelType::Targets => "targets", ModelType::Word => "word", ModelType::Words => "words", ModelType::Ipv4Address => "ipv4_address", ModelType::Ipv4Addresses => "ipv4_addresses", }; f.write_str(s) } } ``` ### `shared/mrpf_core/src/tasks/definitions/recon_api_tasks/upsert_models.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/recon_api_tasks/upsert_models.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::templates::DataKind; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::{debug, info}; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Retrieve data from the given task collection key and upsert the corresponding Recon API model(s) #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct UpsertModels { /// The id for the task id: TaskId, // The root container id of the task collection to insert data into root_container_id: TaskContainerId, /// The output variable to store the results in key: String, } impl UpsertModels { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Create a new UpsertModels task. pub fn new(root_container_id: TaskContainerId, key: String) -> Self { Self { id: TaskDefinitionId::now_v7(), root_container_id, key, } } /// Recreate a UpsertModels from a definition on a runtime task. /// /// # Arguments /// /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let root_container_id: TaskContainerId = serde_json::from_value(definition["root_container_id"].take())?; let key: String = serde_json::from_value(definition["key"].take())?; Ok(Self { id, root_container_id, key, }) } /// Execute the UpsertModels task using the provided database client. /// /// # Arguments /// * `context` - The worker context providing database access. #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), root_container_id = %self.root_container_id, key = %self.key, "Executing UpsertModels task"); let data = context .db()? 
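// UpsertModels only understands a subset of DataKind values (see the match below):
// Target, Domain, SetOfDomains and SetOfIpv4Addresses; anything else is rejected with
// an error. Typical pairing, with a hypothetical key: an earlier task stores
// "new_domains" in the collection and UpsertModels::new(root_container_id,
// "new_domains".to_string()) pushes those rows back into the Recon API.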
.get_data(&self.root_container_id, &self.key) .await?; match data { DataKind::Target(target) => { context.db()?.put_target(&target).await?; } DataKind::Domain(domain) => { context.db()?.put_domains(HashSet::from([domain])).await?; } DataKind::SetOfDomains(domains) => { context.db()?.put_domains(domains).await?; } DataKind::SetOfIpv4Addresses(ipv4_addresses) => { context.db()?.put_ip_addresses(ipv4_addresses).await?; } other => { return Err(crate::Error::Other(format!( "Unsupported model type '{}' in UpsertModels task", other ))); } }; debug!(key = %self.key, "Stored target in task collection data"); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for UpsertModels { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::ReconDataAccess, ]; const TIMEOUT_SECS: u32 = 120; // Recon API upserts moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/mod.rs` ```rust pub mod tcp_syn; pub mod tcp_syn_aggregator; pub mod tcp_syn_generator; pub use tcp_syn::TcpSyn; pub use tcp_syn_aggregator::TcpSynAggregator; pub use tcp_syn_generator::TcpSynGenerator; ``` ### `shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/tcp_syn.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/tcp_syn.rs` ```rust use crate::tasks::models::{ TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::Result; use crate::recon::Ipv4Address; use crate::templates::DataKind; use mrpf_engine::{Ipv4AddressRanges, PortRanges, ScannerConfig}; #[cfg(feature = "task_execution")] use mrpf_engine::utils::get_interface; #[cfg(feature = "task_execution")] use mrpf_tcp_syn_scanner::TcpSynScanner; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::{debug, info}; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// A single TcpSyn scan task. #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, Clone)] #[serde(tag = "tcp_syn", rename_all = "snake_case")] pub struct TcpSyn { /// The id of the task template. /// /// When loaded from a runtime task, it will use it's TaskId. When created as /// a reusable template, it will use TaskDefinitionId. id: TaskDefinitionId, /// The IPv4 addresses to scan. ipv4_address_ranges: Ipv4AddressRanges, /// The port_ranges to scan. port_ranges: PortRanges, /// The target ID associated with this scan. target_id: String, /// Output key output: String, } impl TaskDefinition for TcpSyn { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::RawSocketAccess, ]; const TIMEOUT_SECS: u32 = 900; // TCP SYN scans can take longer on large ranges /// Execute the TCP SYN scan task. /// /// Returns a set of IPv4Addresses during the scan. #[cfg(feature = "task_execution")] async fn execute(&mut self, context: &ExecutionContext) -> Result { let task_id = self.id().to_string(); info!( task_id = task_id, ip_ranges = %self.ipv4_address_ranges, port_ranges = %self.port_ranges, "Executing TCP SYN scan" ); let interface = get_interface(None); let src_mac = interface.mac.expect("No MAC address found"); // Find the first available Ipv4 Address and use the first ip of the subnet as router_ip. 
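// Worked example of the derivation below, with illustrative addresses: for an
// interface configured as 192.168.1.37/24, network.network() is 192.168.1.0,
// to_bits() + 1 gives 192.168.1.1, so router_ip = 192.168.1.1 and
// interface_ip = 192.168.1.37. The code assumes the gateway is the first host
// address of the subnet.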
let (router_ip, interface_ip) = interface .ips .iter() .find(|network| network.is_ipv4()) .map(|network| match network.ip() { std::net::IpAddr::V4(ipv4) => match network.network().to_canonical() { std::net::IpAddr::V4(network_id) => { let router_id = std::net::Ipv4Addr::from(network_id.to_bits() + 1); (router_id, ipv4) } _ => unreachable!(), }, _ => unreachable!(), }) .expect("No IPv4 address found"); debug!( task_id = task_id, // interface = %interface, src_mac = %src_mac, src_ip = %interface_ip, router_ip = %router_ip, "Building Scanner Config" ); let config = ScannerConfig::builder() .interface(interface) .src_mac(src_mac) .src_ip(interface_ip) .router_ip(router_ip) .build(); let scanner = TcpSynScanner::new(config); debug!(task_id = task_id, "Starting TCP SYN scanner"); let scan_results = scanner.run(self.ipv4_address_ranges.clone(), self.port_ranges.clone()); let ipv4_addresses: HashSet = scan_results .into_iter() .map(|(ip, tcp_ports)| { Ipv4Address::new(ip, self.target_id.clone(), tcp_ports, HashSet::new()) }) .collect(); if ipv4_addresses.is_empty() { info!(task_id = task_id, "No open TCP ports found during scan"); } else { info!( task_id = task_id, found_addresses = ipv4_addresses.len(), "Found IPv4 addresses with open TCP ports" ); debug!( task_id = task_id, "Discovered IPv4 addresses: {:?}", ipv4_addresses ); debug!( task_id = task_id, output_key = self.output, "Storing scan results" ); let data = DataKind::SetOfIpv4Addresses(ipv4_addresses); context .db()? .put_task_data(self.id(), &self.output, data) .await?; } info!(task_id = task_id, "TCP SYN scan completed successfully"); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TcpSyn { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Initialize new TcpSyn task pub fn new( ipv4_address_ranges: Ipv4AddressRanges, port_ranges: PortRanges, target_id: String, output: String, ) -> Self { Self { id: TaskDefinitionId::now_v7(), ipv4_address_ranges, port_ranges, target_id, output, } } /// Recreate a TcpSyn from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task. 
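// Hypothetical definition snapshot matching the field names parsed below; the range
// values are written as plain strings purely for illustration, the real
// Ipv4AddressRanges/PortRanges serialization is defined in mrpf_engine:
//
//     {
//         "ipv4_address_ranges": "203.0.113.0-203.0.113.255",
//         "port_ranges": "80,443,8000-8100",
//         "target_id": "acme-corp",
//         "output": "ipv4addresses"
//     }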
pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let ipv4_address_ranges: Ipv4AddressRanges = serde_json::from_value(definition["ipv4_address_ranges"].take())?; let port_ranges: PortRanges = serde_json::from_value(definition["port_ranges"].take())?; let target_id: String = serde_json::from_value(definition["target_id"].take())?; let output: String = serde_json::from_value(definition["output"].take())?; Ok(Self { id, ipv4_address_ranges, port_ranges, target_id, output, }) } } ``` ### `shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/tcp_syn_aggregator.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/tcp_syn_aggregator.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use serde::{Deserialize, Serialize}; use tracing::info; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct TcpSynAggregator { /// The id for the task id: TaskId, // The container id that contains the task to aggregate results from container_id_to_aggregate: TaskContainerId, /// The output variable of the task collection to store FQDN results in output_key: String, } impl TcpSynAggregator { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } pub fn new(container_id_to_aggregate: TaskContainerId, output_key: String) -> Self { Self { id: TaskDefinitionId::now_v7(), container_id_to_aggregate, output_key, } } /// Recreate a TcpSynAggregator from a definition on a runtime task. /// /// # Arguments /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let container_id_to_aggregate: TaskContainerId = serde_json::from_value(definition["container_id_to_aggregate"].take())?; let output_key: String = serde_json::from_value(definition["output_key"].take())?; Ok(Self { id, container_id_to_aggregate, output_key, }) } /// Aggregate the results from the tasks in the `container_id_to_aggregate` /// and store them in the `root_container_id` under the keys `fqdns` and `organizations`. /// /// # Arguments /// * `context` - The worker context containing the database client. #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), container_id_to_aggregate = %self.container_id_to_aggregate, "Executing TcpSynAggregator task"); let ipv4_address_count = context .db()? .aggregate_set_of_ipv4_addresses( &self.container_id_to_aggregate, &self.output_key, &self.output_key, ) .await?; info!( task_id = %self.id(), container_id_to_aggregate = %self.container_id_to_aggregate, ipv4_address_count = %ipv4_address_count, "Aggregated IPv4 Addresses results for TCP SYN tasks." ); // let fqdn_count = context // .db()? // .aggregate_scalar_data( // &self.container_id_to_aggregate, // "fqdns", // self.fqdn_output_key.as_str(), // DataKind::SetOfFqdns(HashSet::new()), // TODO: this is all a bit wasteful, lets get it to work first.. // ) // .await?; // debug!(fqdn_count = %fqdn_count, "Aggregated FQDNs"); // let cert_org_count = context // .db()? 
// .aggregate_scalar_data( // &self.container_id_to_aggregate, // "organizations", // self.cert_org_output_key.as_str(), // DataKind::SetOfIpv4Addresses(HashSet::new()), // ) // .await?; // debug!(cert_org_count = %cert_org_count, "Aggregated Certificate Organizations"); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for TcpSynAggregator { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[WorkerRequirements::TaskDataAccess]; const TIMEOUT_SECS: u32 = 300; // Aggregators may need some time to combine results } ``` ### `shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/tcp_syn_generator.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tcp_syn_tasks/tcp_syn_generator.rs` ```rust use crate::tasks::definitions::{TcpSyn, TcpSynAggregator}; use crate::tasks::models::{ CompletionPolicy, Task, TaskContainerId, TaskDefinition, TaskId, TaskKind, TaskManagerQueueMessage, TaskResult, WorkerRequirements, }; use crate::{Error, Result, Templates}; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use serde::{Deserialize, Serialize}; use tracing::debug; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Generates distributed TcpSyn tasks #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, Clone)] #[serde(rename_all = "snake_case")] pub struct TcpSynGenerator { /// The ID of the task id: TaskId, /// The root_container_id of the task collection to retrieve data from root_container_id: TaskContainerId, /// The Tcp Syn task destination container ID destination_container_id: TaskContainerId, /// The Tcp Syn Aggregator task will be placed in this container aggregator_container_id: TaskContainerId, /// The IP Addresses to scan ipv4_address_ranges: Templates, /// The ports to scan port_ranges: Templates, /// Target ID associated with this scan target_id: String, /// Maximum fqdns to scan per generated TcpSyn task #[serde(default = "default_max_per_task")] max_per_task: usize, /// The output variable to store results in #[serde(default = "default_output_key")] output: String, // #[serde(default = "default_cert_org_output_key")] // Had this incorrectly. This is present on TLS scraper, not TCP SYN // cert_org_output_key: String, } fn default_max_per_task() -> usize { 100 } fn default_output_key() -> String { "ipv4addresses".to_string() } // fn default_cert_org_output_key() -> String { // "cert_organizations".to_string() // } impl TaskDefinition for TcpSynGenerator { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::QueueAccess, WorkerRequirements::TaskDataAccess, ]; const TIMEOUT_SECS: u32 = 120; // Generators expected to finish quickly /// Generate TCP Syn tasks and send them to the queue. /// /// A typical TCP SYN task structure will look as follows, however, we're able to /// use different structures as needed. /// /// - Sequential container /// - Parallel container /// - TcpSyn 1 task /// - TcpSyn 2 task /// - ... /// - TcpSynAggregator task #[cfg(feature = "task_execution")] async fn execute(&mut self, context: &ExecutionContext) -> Result { let mut expected_task_count = 0; let mut ip_addresses = Ipv4AddressRanges::default(); let mut port_ranges = PortRanges::default(); let (ip_refs, ip_refs_with_path) = self.ipv4_address_ranges.extract_references(); let (port_refs, port_refs_with_path) = self.port_ranges.extract_references(); // Fetch all references from the database let ip_data_map = context .db()? 
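// Fan-out below: one TcpSyn task per (ip_range, port_range) pair plus one
// TcpSynAggregator, so e.g. 3 IP ranges x 2 port ranges yields 3 * 2 + 1 = 7
// expected tasks. `max_per_task` is parsed from the definition but not applied in
// this version of execute(); the TODO below covers smarter range splitting.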
.get_data_map(&self.root_container_id, ip_refs, ip_refs_with_path) .await?; debug!(data = ?ip_data_map, "Fetched data for ipv4_address_ranges"); let port_data_map = context .db()? .get_data_map(&self.root_container_id, port_refs, port_refs_with_path) .await?; debug!(data = ?port_data_map, "Fetched data for port_ranges"); // Parse all templates to get the actual IP addresses and ports for template in &self.ipv4_address_ranges { debug!(template = %template, "ipv4_address_ranges template"); ip_addresses.add_ranges(template.try_from(&ip_data_map)?); } debug!(ranges = %ip_addresses, "Parsed Ipv4AddressRanges"); for template in &self.port_ranges { debug!(template = %template, "port_ranges template"); port_ranges.add_ranges(template.try_from(&port_data_map)?); } debug!(ranges = %port_ranges, "Parsed PortRanges"); // Generate parallel tasks // TODO: We have to be a bit more clever with splitting ranges. For now, we'll launch a task per ip and port range for ip_range in ip_addresses.ranges() { for port_range in port_ranges.ranges() { let tcp_syn = TcpSyn::new( Ipv4AddressRanges::new(vec![*ip_range]), PortRanges::new(vec![*port_range]), self.target_id.clone(), self.output.clone(), ); let task = Task::new(CompletionPolicy::FailOnFailure, TaskKind::TcpSyn(tcp_syn)); debug!(task_id = %task.id(), parent_id = %self.destination_container_id, generator_id = %self.id, "Sending 'TcpSyn task' creation message to queue"); let message = TaskManagerQueueMessage::CreateTask { task, parent_id: Some(self.destination_container_id), generator_id: self.id, }; context.queue_client()?.send_mgr_message(&message).await?; expected_task_count += 1; } } // Create TcpSynAggregator task let aggregator = TcpSynAggregator::new(self.destination_container_id, self.output.clone()); let task = Task::new( CompletionPolicy::FailOnFailure, TaskKind::TcpSynAggregator(aggregator), ); debug!(task_id = %task.id(), parent_id = %self.aggregator_container_id, generator_id = %self.id, "Sending 'TcpSynAggregator task' creation message to queue"); let message = TaskManagerQueueMessage::CreateTask { task, parent_id: Some(self.aggregator_container_id), generator_id: self.id, }; context.queue_client()?.send_mgr_message(&message).await?; expected_task_count += 1; debug!(expected_task_count = %expected_task_count, "TcpSynGenerator done generating tasks"); Ok(TaskResult::GeneratedTasks(expected_task_count)) } } impl TcpSynGenerator { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Recreate a TcpSynGenerator from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task. 
pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let root_container_id: TaskContainerId = serde_json::from_value(definition["root_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'root_container_id' from definition: {}", e )) })?; let destination_container_id: TaskContainerId = serde_json::from_value(definition["destination_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'destination_container_id' from definition: {}", e )) })?; let aggregator_container_id: TaskContainerId = serde_json::from_value(definition["aggregator_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'aggregator_container_id' from definition: {}", e )) })?; let ipv4_address_ranges: Templates = serde_json::from_value(definition["ipv4_address_ranges"].take())?; let port_ranges: Templates = serde_json::from_value(definition["port_ranges"].take())?; let target_id: String = serde_json::from_value(definition["target_id"].take())?; let max_per_task: usize = serde_json::from_value(definition["max_per_task"].take()) .unwrap_or(default_max_per_task()); let output: String = serde_json::from_value(definition["output"].take()).unwrap_or(default_output_key()); Ok(Self { id, root_container_id, destination_container_id, aggregator_container_id, ipv4_address_ranges, port_ranges, target_id, max_per_task, output, }) } } ``` ### `shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/mod.rs` ```rust pub mod tls_scan; pub mod tls_scan_aggregator; pub mod tls_scan_generator; pub use tls_scan::TlsScan; pub use tls_scan_aggregator::TlsScanAggregator; pub use tls_scan_generator::TlsScanGenerator; ``` ### `shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/tls_scan.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/tls_scan.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use crate::templates::DataKind; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::{debug, info}; #[cfg(feature = "task_execution")] use { crate::tasks::ExecutionContext, mrpf_engine::ScannerConfig, mrpf_engine::utils::get_interface, mrpf_tls_scanner::TlsScanner, }; /// A single TlsScan task. #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, Clone)] #[serde(tag = "tls_scan", rename_all = "snake_case")] pub struct TlsScan { /// The id of the task template. /// /// When loaded from a runtime task, it will use it's TaskId. When created as /// a reusable template, it will use TaskDefinitionId. id: TaskDefinitionId, /// The SNI to scan. sni: String, /// The IPv4 addresses to scan. ipv4_address_ranges: Ipv4AddressRanges, /// The port_ranges to scan. port_ranges: PortRanges, /// The target ID associated with this scan. target_id: String, /// Output key output: String, } impl TaskDefinition for TlsScan { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::TaskDataAccess, WorkerRequirements::RawSocketAccess, ]; const TIMEOUT_SECS: u32 = 900; // TLS scans may require more time /// Execute the TLS Scan task. /// /// Returns a set of FQDNs and Certificate Organizations found during the scan. 
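// The scanner input assembled in execute() below is a single
// (SNIs, IP ranges, port ranges) tuple, e.g. with an illustrative hostname:
//
//     let scan_targets = vec![(
//         vec!["www.example.com".to_string()], // SNIs to present
//         ipv4_address_ranges,                 // Ipv4AddressRanges
//         port_ranges,                         // PortRanges
//     )];
//
// Only the unique certificate names are kept from the results; the ip/port each
// FQDN was observed on is currently discarded.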
#[cfg(feature = "task_execution")] async fn execute(&mut self, context: &ExecutionContext) -> Result { let task_id = self.id().to_string(); info!( task_id = task_id, ip_ranges = %self.ipv4_address_ranges, port_ranges = %self.port_ranges, "Executing TLS Scan" ); let interface = get_interface(None); let src_mac = interface.mac.expect("No MAC address found"); // Find the first available Ipv4 Address and use the first ip of the subnet as router_ip. let (router_ip, interface_ip) = interface .ips .iter() .find(|network| network.is_ipv4()) .map(|network| match network.ip() { std::net::IpAddr::V4(ipv4) => match network.network().to_canonical() { std::net::IpAddr::V4(network_id) => { let router_id = std::net::Ipv4Addr::from(network_id.to_bits() + 1); (router_id, ipv4) } _ => unreachable!(), }, _ => unreachable!(), }) .expect("No IPv4 address found"); debug!( task_id = task_id, // interface = %interface, src_mac = %src_mac, src_ip = %interface_ip, router_ip = %router_ip, "Building Scanner Config" ); let config = ScannerConfig::builder() .interface(interface) .src_mac(src_mac) .src_ip(interface_ip) .router_ip(router_ip) .build(); let scanner = TlsScanner::new(config); debug!(task_id = task_id, "Starting TLS Scanner"); let ipv4_address_ranges = std::mem::take(&mut self.ipv4_address_ranges); let port_ranges = std::mem::take(&mut self.port_ranges); let scan_targets = vec![(vec![self.sni.clone()], ipv4_address_ranges, port_ranges)]; let scan_results = scanner.run(scan_targets); // Returns a list of names and the Ip and port they were found on // We will discard the ip/port info for now and just store the unique FQDNs let fqdns_found: HashSet = scan_results.into_keys().collect(); if fqdns_found.is_empty() { info!(task_id = task_id, "No FQDNs found during scan"); } else { info!( task_id = task_id, found_addresses = fqdns_found.len(), "Found FQDNs during TLS scan" ); debug!(task_id = task_id, "Discovered FQDNs: {:?}", fqdns_found); debug!( task_id = task_id, output_key = self.output, "Storing scan results" ); let data = DataKind::SetOfFqdns(fqdns_found); context .db()? .put_task_data(self.id(), &self.output, data) .await?; } info!(task_id = task_id, "TLS Scan completed successfully"); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TlsScan { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Initialize new TlsScan task pub fn new( sni: String, ipv4_address_ranges: Ipv4AddressRanges, port_ranges: PortRanges, target_id: String, output: String, ) -> Self { Self { id: TaskDefinitionId::now_v7(), sni, ipv4_address_ranges, port_ranges, target_id, output, } } /// Recreate a TlsScan from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task. 
pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let sni: String = serde_json::from_value(definition["sni"].take())?; let ipv4_address_ranges: Ipv4AddressRanges = serde_json::from_value(definition["ipv4_address_ranges"].take())?; let port_ranges: PortRanges = serde_json::from_value(definition["port_ranges"].take())?; let target_id: String = serde_json::from_value(definition["target_id"].take())?; let output: String = serde_json::from_value(definition["output"].take())?; Ok(Self { id, sni, ipv4_address_ranges, port_ranges, target_id, output, }) } } ``` ### `shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/tls_scan_aggregator.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/tls_scan_aggregator.rs` ```rust use crate::Result; use crate::tasks::models::{ TaskContainerId, TaskDefinition, TaskDefinitionId, TaskId, TaskResult, TaskState, WorkerRequirements, }; use serde::{Deserialize, Serialize}; use tracing::info; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, PartialEq, Clone)] #[serde(rename_all = "snake_case")] pub struct TlsScanAggregator { /// The id for the task id: TaskId, // The container id that contains the task to aggregate results from container_id_to_aggregate: TaskContainerId, /// The output variable of the task collection to store FQDN results in output_key: String, } impl TlsScanAggregator { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } pub fn new(container_id_to_aggregate: TaskContainerId, output_key: String) -> Self { Self { id: TaskDefinitionId::now_v7(), container_id_to_aggregate, output_key, } } /// Recreate a TlsScanAggregator from a definition on a runtime task. /// /// # Arguments /// * `id` - The TaskId of the runtime task. /// * `definition` - The serde_json::Value snapshot of the template. pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let container_id_to_aggregate: TaskContainerId = serde_json::from_value(definition["container_id_to_aggregate"].take())?; let output_key: String = serde_json::from_value(definition["output_key"].take())?; Ok(Self { id, container_id_to_aggregate, output_key, }) } /// Aggregate the results from the tasks in the `container_id_to_aggregate` /// and store them in the `root_container_id` under the keys `fqdns` and `organizations`. /// /// # Arguments /// * `context` - The worker context containing the database client. #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { info!(task_id = %self.id(), container_id_to_aggregate = %self.container_id_to_aggregate, "Executing TlsScanAggregator task"); let ipv4_address_count = context .db()? .aggregate_set_of_ipv4_addresses( &self.container_id_to_aggregate, &self.output_key, &self.output_key, ) .await?; info!( task_id = %self.id(), container_id_to_aggregate = %self.container_id_to_aggregate, ipv4_address_count = %ipv4_address_count, "Aggregated IPv4 Addresses results for TCP SYN tasks." ); // let fqdn_count = context // .db()? // .aggregate_scalar_data( // &self.container_id_to_aggregate, // "fqdns", // self.fqdn_output_key.as_str(), // DataKind::SetOfFqdns(HashSet::new()), // TODO: this is all a bit wasteful, lets get it to work first.. // ) // .await?; // debug!(fqdn_count = %fqdn_count, "Aggregated FQDNs"); // let cert_org_count = context // .db()? 
// .aggregate_scalar_data( // &self.container_id_to_aggregate, // "organizations", // self.cert_org_output_key.as_str(), // DataKind::SetOfIpv4Addresses(HashSet::new()), // ) // .await?; // debug!(cert_org_count = %cert_org_count, "Aggregated Certificate Organizations"); Ok(TaskResult::Completed(TaskState::Succeeded)) } } impl TaskDefinition for TlsScanAggregator { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[WorkerRequirements::TaskDataAccess]; const TIMEOUT_SECS: u32 = 300; // Aggregators moderate duration } ``` ### `shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/tls_scan_generator.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/definitions/tls_scan_tasks/tls_scan_generator.rs` ```rust use super::TlsScan; use crate::Templates; use crate::tasks::definitions::TlsScanAggregator; use crate::tasks::models::{ CompletionPolicy, Task, TaskContainerId, TaskDefinition, TaskId, TaskKind, TaskManagerQueueMessage, TaskResult, WorkerRequirements, }; use crate::{Error, Result}; use mrpf_engine::{Ipv4AddressRanges, PortRanges}; use serde::{Deserialize, Serialize}; use tracing::debug; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; /// Generates distributed TlsScan tasks #[derive(Serialize, Deserialize, Debug, sqlx::FromRow, Clone)] #[serde(rename_all = "snake_case")] pub struct TlsScanGenerator { /// The ID of the task id: TaskId, /// The root_container_id of the task collection to retrieve data from root_container_id: TaskContainerId, /// The TlsScan task destination container ID destination_container_id: TaskContainerId, /// The TlsScan Aggregator task will be placed in this container aggregator_container_id: TaskContainerId, /// The SNIs to scan snis: Templates, /// The IP Addresses to scan ipv4_address_ranges: Templates, /// The ports to scan port_ranges: Templates, /// Target ID associated with this scan target_id: String, /// Maximum FQDNs to scan per generated TlsScan task #[serde(default = "default_max_per_task")] max_per_task: usize, /// The output variable to store results in #[serde(default = "default_output_key")] output: String, // #[serde(default = "default_cert_org_output_key")] // Had this incorrectly. This is present on TLS scraper, not TCP SYN // cert_org_output_key: String, } fn default_max_per_task() -> usize { 100 } fn default_output_key() -> String { "fqdns".to_string() } // fn default_cert_org_output_key() -> String { // "cert_organizations".to_string() // } impl TaskDefinition for TlsScanGenerator { const WORKER_REQUIREMENTS: &'static [WorkerRequirements] = &[ WorkerRequirements::QueueAccess, WorkerRequirements::TaskDataAccess, ]; const TIMEOUT_SECS: u32 = 120; // Generators expected to finish quickly /// Generate TlsScan tasks and send them to the queue. /// /// A typical TLS scan task structure will look as follows, however, we're able to /// use different structures as needed. /// /// - Sequential container /// - Parallel container /// - TlsScan 1 task /// - TlsScan 2 task /// - ... /// - TlsScanAggregator task #[cfg(feature = "task_execution")] async fn execute(&mut self, context: &ExecutionContext) -> Result { let mut expected_task_count = 0; let mut ip_addresses = Ipv4AddressRanges::default(); let mut port_ranges = PortRanges::default(); let (ip_refs, ip_refs_with_path) = self.ipv4_address_ranges.extract_references(); let (port_refs, port_refs_with_path) = self.port_ranges.extract_references(); // Fetch all references from the database let ip_data_map = context .db()?
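// Fan-out below: one TlsScan task per (ip_range, port_range, sni) combination plus
// one TlsScanAggregator, so e.g. 2 IP ranges x 1 port range x 3 SNIs yields
// 2 * 1 * 3 + 1 = 7 expected tasks. As in TcpSynGenerator, `max_per_task` is parsed
// but not applied in this version of execute(); see the TODO below on range splitting.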
.get_data_map(&self.root_container_id, ip_refs, ip_refs_with_path) .await?; debug!(data = ?ip_data_map, "Fetched data for ipv4_address_ranges"); let port_data_map = context .db()? .get_data_map(&self.root_container_id, port_refs, port_refs_with_path) .await?; debug!(data = ?port_data_map, "Fetched data for port_ranges"); // Parse all templates to get the actual IP addresses and ports for template in &self.ipv4_address_ranges { debug!(template = %template, "ipv4_address_ranges template"); ip_addresses.add_ranges(template.try_from(&ip_data_map)?); } debug!(ranges = %ip_addresses, "Parsed Ipv4AddressRanges"); for template in &self.port_ranges { debug!(template = %template, "port_ranges template"); port_ranges.add_ranges(template.try_from(&port_data_map)?); } debug!(ranges = %port_ranges, "Parsed PortRanges"); // Generate parallel tasks // TODO: We have to be a bit more clever with splitting ranges. For now, we'll launch a task per sni, ip and port range for ip_range in ip_addresses.ranges() { for port_range in port_ranges.ranges() { for sni in &self.snis { let tls_scan = TlsScan::new( sni.to_string(), Ipv4AddressRanges::new(vec![*ip_range]), PortRanges::new(vec![*port_range]), self.target_id.clone(), self.output.clone(), ); let task = Task::new(CompletionPolicy::FailOnFailure, TaskKind::TlsScan(tls_scan)); debug!(task_id = %task.id(), parent_id = %self.destination_container_id, generator_id = %self.id, "Sending 'TlsScan task' creation message to queue"); let message = TaskManagerQueueMessage::CreateTask { task, parent_id: Some(self.destination_container_id), generator_id: self.id, }; context.queue_client()?.send_mgr_message(&message).await?; expected_task_count += 1; } } } // Create TlsScanAggregator task let aggregator = TlsScanAggregator::new(self.destination_container_id, self.output.clone()); let task = Task::new( CompletionPolicy::FailOnFailure, TaskKind::TlsScanAggregator(aggregator), ); debug!(task_id = %task.id(), parent_id = %self.aggregator_container_id, generator_id = %self.id, "Sending 'TlsScanAggregator task' creation message to queue"); let message = TaskManagerQueueMessage::CreateTask { task, parent_id: Some(self.aggregator_container_id), generator_id: self.id, }; context.queue_client()?.send_mgr_message(&message).await?; expected_task_count += 1; debug!(expected_task_count = %expected_task_count, "TlsScanGenerator done generating tasks"); Ok(TaskResult::GeneratedTasks(expected_task_count)) } } impl TlsScanGenerator { /// Return the runtime task's unique identifier. pub fn id(&self) -> &TaskId { &self.id } /// Recreate a TlsScanGenerator from a definition on a runtime task. /// /// The id here is the TaskId of the runtime task.
pub fn from_definition(id: TaskId, mut definition: serde_json::Value) -> Result { let root_container_id: TaskContainerId = serde_json::from_value(definition["root_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'root_container_id' from definition: {}", e )) })?; let destination_container_id: TaskContainerId = serde_json::from_value(definition["destination_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'destination_container_id' from definition: {}", e )) })?; let aggregator_container_id: TaskContainerId = serde_json::from_value(definition["aggregator_container_id"].take()).map_err(|e| { Error::Other(format!( "Failed to parse 'aggregator_container_id' from definition: {}", e )) })?; let snis: Templates = serde_json::from_value(definition["snis"].take())?; let ipv4_address_ranges: Templates = serde_json::from_value(definition["ipv4_address_ranges"].take())?; let port_ranges: Templates = serde_json::from_value(definition["port_ranges"].take())?; let target_id: String = serde_json::from_value(definition["target_id"].take())?; let max_per_task: usize = serde_json::from_value(definition["max_per_task"].take()) .unwrap_or(default_max_per_task()); let output: String = serde_json::from_value(definition["output"].take()).unwrap_or(default_output_key()); Ok(Self { id, root_container_id, destination_container_id, aggregator_container_id, snis, ipv4_address_ranges, port_ranges, target_id, max_per_task, output, }) } } ``` ### `shared/mrpf_core/src/tasks/execution_context.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/execution_context.rs` ```rust use crate::DbClient; use crate::tasks::QueueClient; use crate::{Error, Result}; /// Collection of client handles a worker may use. /// /// Based on the Task WorkerRequirements, we can use the builder to generate the /// final WorkerClients struct with only the clients we need. /// /// Example: /// ```ignore /// let clients = WorkerClients::builder() /// .db(db_client) /// .db(db_client) /// .queue(queue_client) /// .build(); /// ``` /// TODO: RENAME THIS TO EXECUTION_CONTEXT ? pub struct ExecutionContext { db: Option, queue_client: Option, } impl ExecutionContext { /// Existing constructor retained for compatibility. Prefer using the builder. pub fn new(db: Option, queue_client: Option) -> Self { Self { db, queue_client } } /// Start a new builder instance. pub fn builder() -> ExecutionContextBuilder { ExecutionContextBuilder::default() } pub fn db(&self) -> Result<&DbClient> { self.db .as_ref() .ok_or_else(|| Error::InvalidExecutionContext("Task DB client is not available".into())) } pub fn queue_client(&self) -> Result<&QueueClient> { self.queue_client .as_ref() .ok_or_else(|| Error::InvalidExecutionContext("Queue client is not available".into())) } } /// Manual builder (no external crate) for `WorkerClients`. #[derive(Default)] pub struct ExecutionContextBuilder { db: Option, queue_client: Option, } impl ExecutionContextBuilder { /// Provide a task DB client. pub fn db(mut self, client: DbClient) -> Self { self.db = Some(client); self } /// Provide a queue client. pub fn queue(mut self, client: QueueClient) -> Self { self.queue_client = Some(client); self } /// Finish building. 
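// End-to-end builder sketch (client construction elided); only the clients required
// by a task's WorkerRequirements need to be supplied:
//
//     let context = ExecutionContext::builder()
//         .db(db_client)
//         .queue(queue_client)
//         .build();
//     // Accessors return Err(Error::InvalidExecutionContext(..)) for missing clients:
//     let db = context.db()?;
//     let queue = context.queue_client()?;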
pub fn build(self) -> ExecutionContext { ExecutionContext { db: self.db, queue_client: self.queue_client, } } } ``` ### `shared/mrpf_core/src/tasks/manager/handler.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/manager/handler.rs` ```rust use crate::tasks::QueueClient; use crate::tasks::models::{ Task, TaskCollection, TaskContainerId, TaskId, TaskManagerQueueMessage, TaskState, WorkerQueueMessage, }; use crate::templates::DataMap; use crate::{DbClient, Error, Result}; use aws_sdk_ec2::types::InstanceType; use futures::future::join_all; use std::collections::{HashMap, HashSet}; use tracing::{debug, error, info, warn}; const MAX_BARE_METAL_WORKERS: i32 = 5; const BARE_METAL_WORKER_LAUNCH_TEMPLATE: &str = "mathijs-worker-ami"; /// The Manager struct holds the database and queue clients for managing tasks. /// /// We typically initialize the manager in AWS Lambda during cold start and pass /// it along to the handler function to avoid re-initializing clients on every invocation. /// /// This does mean that it's important that the Manager doesn't pass mutable references (not thread safe) or /// holds any state that changes between invocations, as AWS Lambda may reuse the same instance /// for multiple invocations. pub struct Manager { db_client: DbClient, queue_client: QueueClient, } impl Manager { /// Create a new Manager instance with the provided database and queue clients. pub fn new(db_client: DbClient, queue_client: QueueClient) -> Self { Self { db_client, queue_client, } } /// The main handler function for processing tasks. pub async fn run(&self, messages: Vec) -> Result { let mut root_containers_to_check_for_new_tasks = HashSet::new(); let (new_task_collections, tasks_to_create, completed_task_futures) = self.parse_messages(messages).await?; // Process any potential timeouts root_containers_to_check_for_new_tasks.extend(self.process_task_timeouts().await?); // Create new task collections if !new_task_collections.is_empty() { root_containers_to_check_for_new_tasks .extend(self.create_task_collections(new_task_collections).await?); } // Create new tasks from Generators. if !tasks_to_create.is_empty() { self.create_tasks_from_generators(tasks_to_create).await?; } // Process all task completions. // // This action should be performed after processing all our other messages to increase the // likelihood we're dealing with the most up-to-date state of the world. if !completed_task_futures.is_empty() { root_containers_to_check_for_new_tasks.extend( self.process_task_completions(completed_task_futures) .await?, ); } // Update all generator tasks to Succeeded if they have no more tasks to generate. root_containers_to_check_for_new_tasks .extend(self.db_client.complete_generator_tasks().await?); // Check all root containers that had activity (new tasks created, timeouts or tasks completed) for new tasks to start. let mut completed_task_collections = HashMap::new(); if !root_containers_to_check_for_new_tasks.is_empty() { let task_collection_state = self .schedule_next_tasks(root_containers_to_check_for_new_tasks) .await?; for (root_container_id, state) in task_collection_state { info!(root_container_id = %root_container_id, state = ?state, "Task collection state after processing messages"); if state.is_completed() { completed_task_collections.insert(root_container_id, state); } } } // Finalize any task collections that are fully completed/failed. 
self.db_client .finalize_task_collections(completed_task_collections) .await?; // Start any new due task collections self.start_due_task_collections().await?; Ok(()) } /// Process TaskManagerQueueMessages and return the results for further handling. pub async fn parse_messages( &self, messages: Vec, ) -> Result<( Vec, HashMap<(Option, TaskId), Vec>, Vec>>, )> { let mut new_task_collections: Vec = Vec::new(); let mut tasks_to_create: HashMap<(Option, TaskId), Vec> = HashMap::new(); let mut completed_task_futures = Vec::new(); // Parse all messages and categorize them. This allows us // to batch things together in some cases, reducing the round-trip-time // towards the database. for message in messages { match message { TaskManagerQueueMessage::CreateTaskCollection { collection } => { debug!( root_container_id = %collection.root_container_id(), "Received CreateTaskCollection message" ); new_task_collections.push(collection); } TaskManagerQueueMessage::CreateTask { task, parent_id, generator_id, } => { debug!(task_id = %task.id(), generator_task_id = %generator_id, "Received CreateTask message"); tasks_to_create .entry((parent_id, generator_id)) .or_default() .push(task); } TaskManagerQueueMessage::CompleteTask { task_id, state, root_container_id, data, expected_task_count, } => { debug!(task_id = %task_id, state = %state, root_container_id = %root_container_id, "Received CompleteTask message"); completed_task_futures.push(self.complete_task( task_id, state, root_container_id, data, expected_task_count, )); } } } Ok(( new_task_collections, tasks_to_create, completed_task_futures, )) } /// Process all completed tasks by awaiting their futures and collecting the root /// container IDs that need to be checked for new tasks. async fn process_task_completions( &self, completed_task_futures: Vec>>, ) -> Result> { debug!(completed_task_count = %completed_task_futures.len(), "Processing completed tasks"); let mut root_containers_to_check_for_new_tasks = HashSet::new(); let results = join_all(completed_task_futures).await; // let mut root_container_ids_with_failed_tasks = HashSet::new(); for result in results { match result { Ok(root_container_id) => { root_containers_to_check_for_new_tasks.insert(root_container_id); } Err(e) => { error!("Error completing task: {:?}", e); todo!( "Create custom error type that captures the root_container_id of the failed task." ); // root_container_ids_with_failed_tasks.insert(root_container_id); } } } Ok(root_containers_to_check_for_new_tasks) } /// Hande task completion by updating the task state in the database. /// /// # Arguments /// * `task_id` - ID of the task to complete. /// * `state` - New state of the task. /// * `root_container_id` - ID of the root container the task belongs to. /// * `data` - Optional data map to associate with the task. /// * `expected_task_count` - Optional expected task count for the task. /// /// Returns the root container ID upon successful completion. async fn complete_task( &self, task_id: TaskId, state: TaskState, root_container_id: TaskContainerId, data: Option, expected_task_count: Option, ) -> Result { self.db_client .update_task_state(&task_id, &state, data, expected_task_count) .await?; Ok(root_container_id) } /// Create new tasks generated by generator tasks. 
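// Hedged sketch of how `parse_messages` feeds this method: the generic
// parameters are not fully visible in this flattened dump, but the grouping
// key appears to be `(Option<TaskContainerId>, TaskId)` (parent container,
// generator task) with a `Vec<Task>` per key, so two CreateTask messages from
// the same generator and parent collapse into a single `create_tasks` call:
//
//     // tasks_to_create: HashMap<(Option<TaskContainerId>, TaskId), Vec<Task>>
//     tasks_to_create
//         .entry((parent_id, generator_id))
//         .or_default()
//         .push(task_a);
//     tasks_to_create
//         .entry((parent_id, generator_id))
//         .or_default()
//         .push(task_b);
//     // -> one db_client.create_tasks(&[task_a, task_b], generator_id, parent_id)
//     //    round trip (per the CreateTask docs, a None parent falls back to the
//     //    generator's own parent container)
//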
async fn create_tasks_from_generators( &self, tasks_to_create: HashMap<(Option, TaskId), Vec>, ) -> Result { debug!(task_to_create_count = %tasks_to_create.len(), "Creating new tasks from generators"); for ((container_id, generator_id), tasks) in tasks_to_create { self.db_client .create_tasks(&tasks, generator_id, container_id) .await?; } Ok(()) } /// For each root container ID, check if there are new tasks to start and dispatch them to the worker queue. /// /// # Arguments /// * `root_container_ids` - Set of root container IDs to check for new tasks /// /// Returns a Vec for all root containers. async fn schedule_next_tasks( &self, root_container_ids: HashSet, ) -> Result> { info!(root_container_count = %root_container_ids.len(), "Checking for new tasks to start after processing completions"); let mut bare_metal_worker_task_count = 0; let mut results = Vec::new(); for root_container_id in root_container_ids { debug!(root_container_id = %root_container_id, "Retrieving collection..."); let collection = self .db_client .get_task_collection(&root_container_id) .await?; match collection.next_tasks() { Ok(tasks) => { // Set tasks to running in the database // if tasks.is_empty() { // // NOTE: We're descending again into the task collection to determine if it's fully completed/failed. // // This is not ideal for compute perspective, but lets optimize if this becomes a bottleneck. // // This should only be hit when it's either running or completed. Failure cases are already handled // // by returning errors in the next_tasks() call. // let state = collection.state(); // match state { // TaskState::Succeeded => { // info!(root_container_id = %root_container_id, state = ?state, "Task collection state"); // } // TaskState::Failed | TaskState::PartiallyFailed => { // warn!(root_container_id = %root_container_id, state = ?state, "Task collection state"); // } // _ => { // debug!(root_container_id = %root_container_id, state = ?state, "Task collection state"); // } // } // continue; // } let task_ids = tasks.iter().map(|task| *task.id()).collect::>(); debug!(root_container_id = %root_container_id, task_count = %tasks.len(), task_ids = ?task_ids, task_ids_count = %task_ids.len(), "Found new tasks to start"); self.db_client.update_tasks_to_running(&task_ids).await?; // Send tasks to the worker queue for task in tasks { let worker_message = WorkerQueueMessage::StartTask { root_container_id, task: task.clone(), }; if task.requires_internal_worker() { debug!(task_id = %task.id(), "Task requires internal worker, sending to internal worker queue"); self.queue_client .send_int_worker_message(&worker_message) .await?; } else if task.requires_raw_sockets() { debug!(task_id = %task.id(), "Task requires raw socket access, sending to bare metal worker queue"); bare_metal_worker_task_count += 1; self.queue_client .send_bare_metal_worker_message(&worker_message) .await?; } else { debug!(task_id = %task.id(), "Task can be handled by external worker, sending to external worker queue"); self.queue_client .send_ext_worker_message(&worker_message) .await?; } info!(task_id = %task.id(), "Dispatched task to worker queue"); } } Err(err) => { // TODO: Implement proper error codes for this stuff so we can differentiate between code errors and task collection state errors. warn!(error=%err, root_container_id = %root_container_id, "Task collection failed"); } } // After processing, fetch the current state of the root container. 
results.push((root_container_id, collection.state())); } if bare_metal_worker_task_count > 0 { // Optional comma separated list of subnet IDs (different AZs) to try once each. // This is useful to catch any capacity errors in a single AZ. If provided, the // manager will attempt to launch instances in each subnet in order until one // succeeds or all subnets have been tried. // If not provided, we will try the subnet configured directly in the launch template. // // Example env var: BARE_METAL_WORKER_SUBNETS=subnet-aaa,subnet-bbb,subnet-ccc let subnet_ids: Vec = std::env::var("BARE_METAL_WORKER_SUBNETS") .ok() .map(|v| { v.split(',') .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect() }) .unwrap_or_default(); // Convert to HashSet for pop semantics (order not guaranteed, acceptable). let subnet_set: HashSet = subnet_ids.into_iter().collect(); self.start_bare_metal_workers(bare_metal_worker_task_count, subnet_set) .await?; info!(bare_metal_worker_task_count = %bare_metal_worker_task_count, "Number of tasks requiring bare metal workers dispatched"); } Ok(results) } async fn create_task_collections( &self, collections: Vec, ) -> Result> { debug!(task_collection_count = %collections.len(), "Creating new task collections"); let mut root_container_ids = HashSet::new(); for collection in collections { let root_container_id = *collection.root_container_id(); self.db_client.create_task_collection(collection).await?; root_container_ids.insert(root_container_id); } Ok(root_container_ids) } /// Start bare metal workers to handle tasks requiring raw socket access. /// /// TODO: Is this the right place for this logic? async fn start_bare_metal_workers( &self, task_count: i32, subnet_ids: HashSet, ) -> Result { let ec2_client = aws_sdk_ec2::Client::new(&aws_config::load_from_env().await); let task_count = if task_count > MAX_BARE_METAL_WORKERS { warn!(requested_task_count = %task_count, max_allowed = %MAX_BARE_METAL_WORKERS, "Requested bare metal worker task count exceeds maximum allowed, capping to maximum"); MAX_BARE_METAL_WORKERS } else { task_count }; if subnet_ids.is_empty() { debug!( "No subnet IDs provided for bare metal worker launch, using launch template default subnet and instance type" ); return self .start_bare_metal_worker_in_subnet(&ec2_client, task_count, None, None) .await; } // We should get it from a environment variable, or perhaps investigate launch template defaults. // aws_sdk_ec2::types::InstanceType::from(instance_type) // Note: the order for instance alternatives matters, we try None (launch template default) first. // For subnets we could benefit from the unordered set to introduce some randomness on which subnets // are tried first. 
// Check here for a nice filtering option on costs for the Arm64 architecture let instance_alternatives = vec![ None, Some(InstanceType::T4gMicro), Some(InstanceType::T4gSmall), Some(InstanceType::C8gdMedium), Some(InstanceType::C7gMedium), ]; for instance_type in instance_alternatives.into_iter() { for subnet_id in subnet_ids.iter() { debug!(subnet_id = %subnet_id, instance_type = ?instance_type, "Attempting to start bare metal workers in subnet"); match self .start_bare_metal_worker_in_subnet( &ec2_client, task_count, Some(subnet_id.clone()), instance_type.as_ref(), ) .await { Ok(_) => { info!(subnet_id = %subnet_id, instance_type = ?instance_type, "Successfully started bare metal workers in subnet"); return Ok(()); } Err(_) => { // Already logged in start_bare_metal_worker_in_subnet // warn!(error = %err, "Failed to start bare metal workers in subnet, trying next subnet if available"); } }; } } Err(Error::Other( "Could not start bare metal workers in any of the provided subnets or instance types" .into(), )) } /// Start bare metal workers in the specified subnet. /// /// # Arguments /// * `ec2_client` - AWS EC2 client to use for launching instances. /// * `task_count` - Number of bare metal worker instances to start. /// * `subnet_id` - Optional subnet ID to launch instances in. If not provided, we will rely on the launch template's default subnet. async fn start_bare_metal_worker_in_subnet( &self, ec2_client: &aws_sdk_ec2::Client, task_count: i32, subnet_id: Option, instance_type: Option<&InstanceType>, ) -> Result { let task_count = if task_count > MAX_BARE_METAL_WORKERS { warn!(requested_task_count = %task_count, max_allowed = %MAX_BARE_METAL_WORKERS, "Requested bare metal worker task count exceeds maximum allowed, capping to maximum"); MAX_BARE_METAL_WORKERS } else { task_count }; let launch_template = aws_sdk_ec2::types::LaunchTemplateSpecification::builder() .launch_template_name(BARE_METAL_WORKER_LAUNCH_TEMPLATE) .set_version(Some("$Latest".to_string())) .build(); // Create instances and add more tags to be able to track what the instances belong to // let additional_tags = aws_sdk_ec2::types::TagSpecification::builder() // .resource_type(aws_sdk_ec2::types::ResourceType::Instance) // .tags( // aws_sdk_ec2::types::Tag::builder() // .key("task_collection_ids") // .value( // task_collection_ids // .iter() // .map(|id| id.to_string()) // .collect::>() // .join(","), // ) // .build(), // ) // .tags( // aws_sdk_ec2::types::Tag::builder() // .key("Purpose") // .value("BareMetalWorker") // .build(), // ) // .build(); let mut run_instances = ec2_client .run_instances() .launch_template(launch_template) .min_count(task_count) .max_count(task_count); run_instances = if let Some(subnet_id) = subnet_id { debug!(subnet_id = %subnet_id, "Using provided subnet ID for bare metal worker launch"); let nic_config = aws_sdk_ec2::types::InstanceNetworkInterfaceSpecification::builder() .device_index(0) // This is needed to ensure we're overriding the primary network interface .subnet_id(subnet_id) // .associate_public_ip_address(true) // .interface_type("interface") // .description("mathijs-worker-ami. Custom subnet passed on through task manager.") // .delete_on_termination(true) // TODO: It seems AWS does inherit the security group for the nic which is nice. This might be the same // for the other settings, if so, we should only add subnet_id here. // .set_groups(None) // Use security groups from launch template. Will this work? I'm afraid we'll have to pass it on again.. 
:( .build(); run_instances .set_subnet_id(None) .network_interfaces(nic_config) } else { run_instances }; run_instances = if let Some(instance_type) = instance_type { debug!(instance_type = %instance_type, "Using provided instance type for bare metal worker launch"); run_instances.instance_type(instance_type.clone()) } else { run_instances }; match run_instances.send().await { Ok(_) => { info!(task_count = %task_count, "Successfully started bare metal workers"); } Err(err) => { // TODO: We should catch specific InsufficientInstanceCapacity error to decide whether to try another subnet. if let Some(service_error) = err.as_service_error() { match service_error.meta().code() { Some("InsufficientInstanceCapacity") => { warn!( "Insufficient instance capacity in this subnet/instance type, trying next subnet if available" ); } error_code => { error!(message = ?service_error.meta().message(), error_code = ?error_code,"Could not start bare metal workers"); } } } else { error!(error = ?err.raw_response(), "Could not start bare metal workers"); } return Err(Error::Other("Could not start bare metal workers".into())); } } Ok(()) } /// Process any due task collections by starting them and sending messages to Task Manager queue. /// /// # Arguments /// * `db` - Reference to the DbClient for database operations. /// * `queue_client` - Reference to the QueueClient for sending messages. async fn start_due_task_collections(&self) -> Result<(), Error> { info!("Checking for due jobs to start"); let started = self.db_client.start_due_jobs().await?; if started.is_empty() { info!("No task collections were due to start"); return Ok(()); } info!( started_count = started.len(), "Started scheduled task collections. Sending CreateTaskCollection messages to Task Manager queue..." ); // Iterate and attempt to serialize each started TaskCollection individually so that // a single failure does not abort processing of the rest. Collect successes and failures. let mut successful_messages: Vec = Vec::new(); let mut failed_collections = Vec::new(); for collection in started.into_iter() { debug!(collection = %collection, "Generating unique UUIDs for task collection"); let collection_instance = collection.generate_definition_with_uuids(); debug!( collection = %collection_instance.to_string(), "Creating TaskManagerQueueMessage from TaskCollection" ); match TaskManagerQueueMessage::task_collection_from_json(collection_instance) { Ok(msg) => successful_messages.push(msg), Err(e) => { error!(error = ?e, failed_collection = ?collection, "Failed to serialize TaskCollection into TaskManagerQueueMessage"); failed_collections.push(collection); } } } if successful_messages.is_empty() { info!( failed_count = failed_collections.len(), "No TaskManagerQueueMessages were successfully created; skipping send" ); } else { self.queue_client .send_mgr_messages(&successful_messages) .await?; info!( enqueued_count = successful_messages.len(), failed_count = failed_collections.len(), "Enqueued CreateTaskCollection messages for started collections" ); } if !failed_collections.is_empty() { // TODO: CREATE A CANCEL_TASK_COLLECTION MESSAGE FOR THE TASK MANAGER. error!( failed_collection_ids = failed_collections.len(), "Summary: some TaskCollections failed serialization" ); } Ok(()) } /// Update task status for any tasks that have exceeded their allowed runtime. /// /// # Arguments /// * `db` - Reference to the TaskDbClient for database operations. 
async fn process_task_timeouts(&self) -> Result, Error> { info!("Processing task timeouts"); let timed_out_task_ids = self.db_client.cancel_timed_out_tasks().await?; if timed_out_task_ids.is_empty() { info!("No timed-out tasks found"); Ok(timed_out_task_ids) } else { info!( timed_out_count = timed_out_task_ids.len(), "Found timed-out tasks" ); Ok(timed_out_task_ids) } } } ``` ### `shared/mrpf_core/src/tasks/mod.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/mod.rs` ```rust //! Task Manager Models Module //! //! This module contains all core data structures, enums, and types used by the MRPF Task Manager. //! It defines the schema for tasks, containers, results, policies, and other entities required for //! distributed scan orchestration and scheduling. All types are designed for high performance, //! minimal memory usage, and easy serialization for inter-thread and inter-process communication. //! //! # Design Notes //! - Shared by both the task manager and other MRPF components (API, engine, etc.) //! - Follows MRPF conventions: minimal dependencies, extensibility, and ARM/Linux/macOS compatibility. //! - Types are optimized for stack allocation and zero-copy parsing where possible. //! //! # See Also //! - `mrpf_models` for global/shared types //! - Project book and README for schema diagrams and design notes #[cfg(feature = "task_shared")] mod execution_context; #[cfg(feature = "task_shared")] pub use execution_context::ExecutionContext; #[cfg(feature = "task_shared")] mod queue_client; #[cfg(feature = "task_shared")] pub use queue_client::QueueClient; #[cfg(feature = "task_shared")] pub mod models { mod completion_policy; mod messages; mod task; mod task_collection; mod task_container; mod task_definition; mod task_kind; mod task_result; mod task_state; mod worker_requirements; pub use completion_policy::CompletionPolicy; pub use messages::{TaskManagerQueueMessage, WorkerQueueMessage}; pub use task::{Task, TaskId}; pub use task_collection::TaskCollection; pub use task_container::{TaskContainer, TaskContainerChild, TaskContainerId}; pub use task_definition::{TaskDefinition, TaskDefinitionId}; pub use task_kind::TaskKind; pub use task_result::TaskResult; pub use task_state::TaskState; pub use worker_requirements::WorkerRequirements; } #[cfg(feature = "task_shared")] mod definitions { mod conditional; mod dns_scan_tasks; mod filter; mod insert_data; mod notifications; mod recon_api_tasks; mod tcp_syn_tasks; mod tls_scan_tasks; pub use conditional::Conditional; pub use dns_scan_tasks::{DnsScan, DnsScanAggregator, DnsScanGenerator}; pub use filter::Filter; pub use insert_data::InsertData; pub use notifications::{ ErrorNotification, GenericNotification, TaskCollectionResultNotification, }; pub use recon_api_tasks::{ListModels, UpsertModels}; pub use tcp_syn_tasks::{TcpSyn, TcpSynAggregator, TcpSynGenerator}; pub use tls_scan_tasks::{TlsScan, TlsScanAggregator, TlsScanGenerator}; } #[cfg(feature = "task_manager")] pub mod manager { mod handler; pub use handler::Manager; } ``` ### `shared/mrpf_core/src/tasks/models/completion_policy.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/completion_policy.rs` ```rust use serde::{Deserialize, Serialize}; use std::fmt; /// An enum that represents the completion policy of a Task. The task manager /// determines what to do when a task completes. /// /// Often we want to be able to manage what happens when a tasks in a sequential container fails. 
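// Hedged example of the representations implied by the serde and sqlx
// attributes below (snake_case on both sides):
//
//     // JSON:     "fail_on_failure" | "continue_on_partial_failure" | "continue_on_failure"
//     // Postgres: enum type `completion_policy` with the same snake_case labels
//
//     let policy: CompletionPolicy = serde_json::from_str("\"fail_on_failure\"")?;
//     assert_eq!(policy, CompletionPolicy::FailOnFailure);
//     assert_eq!(policy.to_string(), "fail_on_failure");
//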
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, sqlx::Type)] #[serde(rename_all = "snake_case")] #[sqlx(type_name = "completion_policy", rename_all = "snake_case")] pub enum CompletionPolicy { /// Fail the task collection if a task fails. This is the default setting for a task. FailOnFailure, /// Continue the task collection if a task partially fails /// /// This is useful if a task consists of multiple sub tasks, and only some of the sub task fail. ContinueOnPartialFailure, /// Continue the task collection if a task fails ContinueOnFailure, } impl fmt::Display for CompletionPolicy { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { CompletionPolicy::FailOnFailure => write!(f, "fail_on_failure"), CompletionPolicy::ContinueOnPartialFailure => write!(f, "continue_on_partial_failure"), CompletionPolicy::ContinueOnFailure => write!(f, "continue_on_failure"), } } } ``` ### `shared/mrpf_core/src/tasks/models/messages.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/messages.rs` ```rust use super::{Task, TaskCollection, TaskContainerId, TaskId, TaskState}; use crate::Result; use crate::templates::DataMap; use serde::{Deserialize, Serialize}; /// Messages destined for the task manager queue /// /// Naming convention of each enum is that they should describe the action being taken. #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[expect(clippy::large_enum_variant)] // TODO: Consider boxing large variants if performance becomes an issue pub enum TaskManagerQueueMessage { /// Create a task collection. /// /// This is the initial state of a single run of a set of tasks. /// After creating the full collection, the task manager will start /// the first task(s). CreateTaskCollection { collection: TaskCollection }, /// A new task that should be created in the database for an active set of tasks. /// /// This message contains the task details and the task id of the generator /// that sent the request. CreateTask { /// The task to be created task: Task, /// The ID of the task that generated this task generator_id: TaskId, /// The ID of the container to create the task in. If None, we will retrieve the parent container /// of the generator task and use that. parent_id: Option, }, /// Task completion message. /// /// After receiving a completion message, the task manager will update the /// task's status and may start new tasks if necessary. CompleteTask { /// The ID of the completed task task_id: TaskId, /// The final state of the completed task state: TaskState, /// The root container ID of the task collection this task belongs to /// /// This is used to identify which task collection to check for further tasks to start. root_container_id: TaskContainerId, /// Optional data produced by the task data: Option, /// Optional expected_task_count counter. Only applicable when /// state is `TaskState::PendingTaskCreation` and type is a container. expected_task_count: Option, }, } impl TaskManagerQueueMessage { /// Create a TaskCollection from a JSON string representing a TaskManagerQueueMessage pub fn task_collection_from_json(value: serde_json::Value) -> Result { let collection: TaskCollection = serde_json::from_value(value)?; Ok(TaskManagerQueueMessage::CreateTaskCollection { collection }) } } /// Messages destined for the workers queue /// /// Naming convention of each enum is that they should describe the action being taken. 
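// Hedged example of the JSON shape produced for `TaskManagerQueueMessage`
// (defined above): the variants are externally tagged (there is no
// `#[serde(tag = ...)]`) and renamed to snake_case, so a completion message
// should look roughly like this (UUIDs are placeholders):
//
//     {
//       "complete_task": {
//         "task_id": "<task uuid>",
//         "state": "succeeded",
//         "root_container_id": "<root container uuid>",
//         "data": null,
//         "expected_task_count": null
//       }
//     }
//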
#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum WorkerQueueMessage { /// Starts a new task on a worker StartTask { /// The root container ID of the task collection. This is required /// to inform the Task Manager which task collection this task belongs to when /// it completes. root_container_id: TaskContainerId, /// The task to be started task: Task, }, } ``` ### `shared/mrpf_core/src/tasks/models/task.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task.rs` ```rust use super::{CompletionPolicy, TaskKind, TaskResult, TaskState}; use crate::Result; use serde::{Deserialize, Serialize}; use sqlx::types::chrono::{DateTime, Utc}; use uuid::Uuid; #[cfg(feature = "task_execution")] use crate::tasks::ExecutionContext; #[cfg(feature = "sqlx")] use sqlx::{FromRow, Row, postgres::PgRow}; pub type TaskId = Uuid; /// A runtime task in the task manager. /// /// These tasks get scheduled by the task manager and executed by workers. The /// workers retrieve the required arguments from the task template snapshot. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Task { /// The unique identifier of the task /// /// Task Id's will be generated by our code, no automatic generation /// will be done by the database. This is required to ensure idempotency /// as our queue offers at-least-once delivery. id: TaskId, /// The current state of the task state: TaskState, /// The completion policy for the task completion_policy: CompletionPolicy, /// The task template snapshot /// /// The arguments required by the workers to execute the task. These arguments are either: /// - Generated from an existing task template during task creation /// - Generated through a TaskGenerator task /// /// Example /// ======= /// /// ```json /// { /// "tcp_syn_generator": { /// "ipv4_addresses": ["192.168.0.0/16"], /// "ports": [443, 8443] /// } /// } /// ``` definition: TaskKind, // created_task_ids: Option>, // expected_task_count: Option, /// When the task transitioned to Running (UTC) started_at: Option>, } impl Task { pub fn id(&self) -> &TaskId { &self.id } pub fn state(&self) -> TaskState { self.state } pub fn completion_policy(&self) -> &CompletionPolicy { &self.completion_policy } pub fn definition(&self) -> &TaskKind { &self.definition } pub fn started_at(&self) -> Option<&DateTime> { self.started_at.as_ref() } pub fn timeout_secs(&self) -> u32 { self.definition.timeout_secs() } pub fn requires_internal_worker(&self) -> bool { self.definition.requires_internal_worker() } pub fn requires_raw_sockets(&self) -> bool { self.definition.requires_raw_sockets() } #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { self.definition.execute(context).await } pub fn new(completion_policy: CompletionPolicy, definition: TaskKind) -> Self { Task { id: TaskId::now_v7(), state: TaskState::Pending, completion_policy, definition, started_at: None, } } } #[cfg(feature = "sqlx")] impl FromRow<'_, PgRow> for Task { /// Constructs a Task from a database row. 
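// Hedged usage sketch for `Task::new` (defined above); `from_row` below
// reconstructs the same fields from a database row. A new task gets a
// time-ordered UUIDv7 id and starts in `TaskState::Pending` with no
// `started_at` timestamp. `definition` stands in for any `TaskKind` value
// built elsewhere; its concrete constructor is not shown here.
//
//     let task = Task::new(CompletionPolicy::FailOnFailure, definition);
//     assert_eq!(task.state(), TaskState::Pending);
//     assert!(task.started_at().is_none());
//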
fn from_row(row: &PgRow) -> sqlx::Result { let id = row.try_get::("id")?; let state = row.try_get::("state")?; let completion_policy = row.try_get::("completion_policy")?; let definition = TaskKind::from_row(row)?; let started_at: Option> = row.try_get("started_at").unwrap_or(None); Ok(Task { id, state, completion_policy, definition, started_at, }) } } ``` ### `shared/mrpf_core/src/tasks/models/task_collection.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task_collection.rs` ```rust use super::{Task, TaskContainer, TaskContainerId, TaskState}; use crate::Result; use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// A collection of tasks and task containers, organized in a hierarchical structure. /// /// The root of the hierarchy is a `TaskContainer` that does not have a parent. #[derive(Debug, Clone)] pub struct TaskCollection { root_container: TaskContainer, } impl TaskCollection { pub fn root_container_id(&self) -> &TaskContainerId { self.root_container.id() } /// Get the current state of the task collection, which is the state of the root container. pub fn state(&self) -> TaskState { self.root_container.state() } /// Consumes the TaskCollection and returns the root container. pub(crate) fn into_root_container(self) -> TaskContainer { self.root_container } pub(crate) fn new(root_container: TaskContainer) -> Result { Ok(Self { root_container }) } /// Get the next tasks that are ready to be executed in the task collection. /// /// TODO: Should return an enum with the various results pub(crate) fn next_tasks(&self) -> Result> { match self.root_container.next_tasks() { Ok(tasks) => Ok(tasks), Err(e) => Err(e), } } } /// Custom Serialize implementation to get rid of the outer struct when serializing. impl Serialize for TaskCollection { fn serialize(&self, serializer: S) -> Result where S: Serializer, { self.root_container.serialize(serializer) } } /// Custom Deserialize implementation to get rid of the outer struct when deserializing. impl<'de> Deserialize<'de> for TaskCollection { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let root_container = TaskContainer::deserialize(deserializer)?; // We've removed the parent_id so can't check this anymore this way. Perhaps // there's another way? // if !root_container.is_root_container() { // return Err(serde::de::Error::custom( // "Root container must not have a parent_id", // )); // } Ok(TaskCollection { root_container }) } } ``` ### `shared/mrpf_core/src/tasks/models/task_container.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task_container.rs` ```rust use super::{Task, TaskState}; use crate::{Error, Result}; use serde::{Deserialize, Serialize}; use std::any::Any; use std::fmt; use std::vec::IntoIter; use tracing::debug; use uuid::Uuid; #[cfg(feature = "sqlx")] use {sqlx::Row, sqlx::postgres::PgRow}; pub type TaskContainerId = Uuid; /// A child of a TaskContainer can either be another TaskContainer or a Task. 
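// Hedged note on serialization: because `TaskCollection` (de)serializes
// straight through its root container (see the custom impls above), a
// collection's JSON is simply the root `TaskContainer` object, with each child
// externally tagged as either `container` or `task`. Roughly:
//
//     {
//       "id": "<root container uuid>",
//       "execution_mode": "sequential",
//       "children": [
//         { "task": { ... } },
//         { "container": { ... } }
//       ]
//     }
//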
#[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] #[expect(clippy::large_enum_variant)] // TODO: Consider boxing large variants if performance becomes an issue pub enum TaskContainerChild { Container(TaskContainer), Task(Task), } impl TaskContainerChild { /// Returns the unique identifier of the child pub fn id(&self) -> &Uuid { match self { TaskContainerChild::Container(container) => &container.id, TaskContainerChild::Task(task) => task.id(), } } /// Returns the state of the child pub fn state(&self) -> TaskState { match self { TaskContainerChild::Container(container) => container.state(), TaskContainerChild::Task(task) => task.state(), } } } /// A container Task #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] pub struct TaskContainer { /// The unique identifier for the task id: TaskContainerId, /// The execution mode for the task execution_mode: TaskExecutionMode, /// The task children. These can be either TaskContainers or Tasks. /// The order of the children matters when the execution mode is Sequential. children: Vec, } impl TaskContainer { /// Returns a reference to the task container's unique identifier. pub fn id(&self) -> &TaskContainerId { &self.id } /// Returns the current state of the task container. /// /// The state is derived from the states of its children. pub fn state(&self) -> TaskState { // TODO: We're not taking CompletionPolicy into account. // NOTE: we strictly don't need to count everything but // makes it easier in the future to generate metrics. let mut failed_tasks: u16 = 0; let mut partially_failed_tasks: u16 = 0; let mut succeeded_tasks: u16 = 0; let mut running_tasks: u16 = 0; let mut pending_tasks: u16 = 0; let mut total_tasks: u16 = 0; for child in &self.children { let child_state = child.state(); match child_state { TaskState::Disabled => continue, TaskState::Failed | TaskState::Timeout => { failed_tasks += 1; } TaskState::Running | TaskState::PendingTaskCreation => { running_tasks += 1; } TaskState::Pending => { pending_tasks += 1; } TaskState::PartiallyFailed => { partially_failed_tasks += 1; } TaskState::Succeeded => { succeeded_tasks += 1; } } total_tasks += 1; } debug!( container_id = %self.id(), total_tasks = %total_tasks, failed_tasks = %failed_tasks, partially_failed_tasks = %partially_failed_tasks, succeeded_tasks = %succeeded_tasks, running_tasks = %running_tasks, pending_tasks = %pending_tasks, "Computed task container state from children" ); if failed_tasks == total_tasks && total_tasks > 0 { return TaskState::Failed; } else if failed_tasks > 0 && succeeded_tasks > 0 && running_tasks == 0 && pending_tasks == 0 { return TaskState::PartiallyFailed; } else if succeeded_tasks == total_tasks && total_tasks > 0 { return TaskState::Succeeded; } else if pending_tasks == total_tasks && total_tasks > 0 { return TaskState::Pending; } else if running_tasks > 0 { return TaskState::Running; } else if partially_failed_tasks > 0 { return TaskState::PartiallyFailed; } else if succeeded_tasks > 0 && pending_tasks > 0 { return TaskState::Pending; } TaskState::Disabled } /// Returns the execution mode of the task. pub fn execution_mode(&self) -> TaskExecutionMode { self.execution_mode } /// Consumes self and returns the children of the container. /// This is useful when inserting into the database. pub(crate) fn into_children(self) -> IntoIter { self.children.into_iter() } /// Replace the children of the container. 
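// Hedged examples of the state derivation implemented in `state()` above
// (disabled children are skipped entirely):
//
//     children: [Succeeded, Running]   -> Running
//     children: [Failed, Succeeded]    -> PartiallyFailed
//     children: [Succeeded, Pending]   -> Pending
//     children: [Failed, Failed]       -> Failed
//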
pub(crate) fn replace_children(&mut self, children: Vec) { self.children = children; } /// Add a TaskContainer child to this container. /// /// The order of adding containers matters when the execution mode is Sequential. pub fn add_container_child(&mut self, container: TaskContainer) { self.children.push(TaskContainerChild::Container(container)); } /// Add a Task child to this container. /// /// The order of adding tasks matters when the execution mode is Sequential. pub fn add_task_child(&mut self, task: Task) { self.children.push(TaskContainerChild::Task(task)); } /// Initial creation of a new TaskContainer. pub fn new(execution_mode: TaskExecutionMode) -> Self { Self { id: TaskContainerId::now_v7(), execution_mode, children: Vec::new(), } } /// Initialize TaskContainer from a PostgreSQL row. #[cfg(feature = "sqlx")] pub fn from_row(row: &PgRow) -> Result { let id = row.try_get("id")?; let execution_mode = row.try_get("execution_mode")?; Ok(Self { id, execution_mode, children: Vec::new(), }) } /// Retrieve the next pending task. /// /// Returns an empty vector if no pending tasks are found. /// Returns an error if a child is in a failed state. pub fn next_tasks(&self) -> Result> { match &self.execution_mode() { TaskExecutionMode::Sequential => { debug!( container_id = %self.id(), children_count = %self.children.len(), "Found sequential container, checking children in order" ); for child in &self.children { match child.state() { TaskState::Running => { debug!(container_id = %self.id(), child_id = %child.id(), "Child is running, no new tasks can be started in sequential container"); return Ok(vec![]); } TaskState::Failed | TaskState::Timeout => { return Err(Error::Other( "Sequential container has a child in a failed state".into(), )); } TaskState::Pending => match child { TaskContainerChild::Task(task) => { debug!(task_id = %task.id(), container_id = %self.id(), "Found pending task in sequential container"); return Ok(vec![task]); } TaskContainerChild::Container(container) => { debug!(container_id = %container.id(), "Descending into child container of sequential container"); let tasks = container.next_tasks()?; if !tasks.is_empty() { debug!(container_id = %self.id(), child_container_id = %container.id(), task_count = %tasks.len(), "Found new tasks in child container"); return Ok(tasks); } } }, _ => continue, } } Ok(vec![]) } TaskExecutionMode::Parallel => { let mut tasks = Vec::new(); debug!( container_id = %self.id(), children_count = %self.children.len(), "Found parallel container, checking all children" ); for child in &self.children { match child.state() { TaskState::Running => { debug!(container_id = %self.id(), child_id = %child.id(), child_type = ?child.type_id(), "Child is running, no new tasks can be started in parallel container"); return Ok(vec![]); } TaskState::Failed | TaskState::Timeout => { return Err(Error::Other( "Parallel container has a child in a failed state".into(), )); } TaskState::Pending => match child { TaskContainerChild::Task(task) => { debug!(task_id = %task.id(), "Found pending task in parallel container"); tasks.push(task); } TaskContainerChild::Container(container) => { debug!(container_id = %container.id(), "Descending into child container of parallel container"); let child_tasks = container.next_tasks()?; debug!(container_id = %container.id(), task_count = %child_tasks.len(), "Found new tasks in child container"); tasks.extend(child_tasks); } }, _ => continue, } } debug!( container_id = %self.id(), task_count = %tasks.len(), "Found pending task(s) in 
parallel container" ); Ok(tasks) } } } } /// The execution mode of a task container determines if its /// child tasks should be run in parallel or sequential. #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, sqlx::Type)] #[serde(rename_all = "snake_case")] #[sqlx(type_name = "task_execution_mode", rename_all = "snake_case")] pub enum TaskExecutionMode { /// Run a task in sequence. Sequential, /// Runs in parallel with any subsequent parallel task in the tasks list. Parallel, } impl fmt::Display for TaskExecutionMode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { TaskExecutionMode::Sequential => write!(f, "sequential"), TaskExecutionMode::Parallel => write!(f, "parallel"), } } } ``` ### `shared/mrpf_core/src/tasks/models/task_definition.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task_definition.rs` ```rust use crate::tasks::models::{TaskResult, WorkerRequirements}; use crate::{Error, Result}; use uuid::Uuid; #[cfg(feature = "task_execution")] pub use crate::tasks::ExecutionContext; pub type TaskDefinitionId = Uuid; /// A task definition in the task manager. /// /// Task definitions determine what need to be executed by a runtime task. When the task gets generated, /// a copy of all relevant data from the definition is used to create the runtime task. /// /// This allows us to re-use definitions for similar tasks and ensures changes to a task definition during an /// active tasks does not affect the task. /// /// TODO: Perhaps we should move TaskDefinition to our models crate? This will allow other crates to use it, e.g. the API crate. pub trait TaskDefinition { // The unique identifier of the definition task. // fn id(&self) -> &TaskDefinitionId; const WORKER_REQUIREMENTS: &'static [WorkerRequirements]; /// Static timeout for tasks of this definition in seconds. /// /// This is copied onto runtime tasks when they are created. Keeping it static avoids /// having to serialize it separately and enables cheap access via associated consts. const TIMEOUT_SECS: u32; fn worker_requirements(&self) -> &'static [WorkerRequirements] { Self::WORKER_REQUIREMENTS } /// Whether this task requires an internal worker. fn requires_internal_worker(&self) -> bool { (Self::WORKER_REQUIREMENTS.contains(&WorkerRequirements::TaskDataAccess) || Self::WORKER_REQUIREMENTS.contains(&WorkerRequirements::ReconDataAccess)) && !Self::WORKER_REQUIREMENTS.contains(&WorkerRequirements::RawSocketAccess) } /// Whether this task requires raw socket access. fn requires_raw_sockets(&self) -> bool { Self::WORKER_REQUIREMENTS.contains(&WorkerRequirements::RawSocketAccess) } /// Timeout accessor (in seconds) for convenience when storing onto runtime tasks. fn timeout_secs(&self) -> u32 { Self::TIMEOUT_SECS } /// Execute the task definition within the provided worker context. 
#[cfg(feature = "task_execution")] async fn execute(&mut self, _context: &ExecutionContext) -> Result { Err(Error::Other(format!( "External worker task execution not implemented for task template {}", std::any::type_name::() ))) } } ``` ### `shared/mrpf_core/src/tasks/models/task_kind.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task_kind.rs` ```rust use crate::Result; use crate::tasks::ExecutionContext; use crate::tasks::definitions::{ Conditional, DnsScan, DnsScanAggregator, DnsScanGenerator, ErrorNotification, Filter, GenericNotification, InsertData, ListModels, TaskCollectionResultNotification, TcpSyn, TcpSynAggregator, TcpSynGenerator, TlsScan, TlsScanAggregator, TlsScanGenerator, UpsertModels, }; use crate::tasks::models::{TaskDefinition, TaskResult, WorkerRequirements}; use serde::{Deserialize, Serialize}; use std::fmt; #[cfg(feature = "sqlx")] use {crate::Error, crate::tasks::models::TaskId, sqlx::FromRow, sqlx::Row, sqlx::postgres::PgRow}; use mrpf_matchers::GenericMatcher; /// TaskKind holds all possible task types. /// /// The basic implementation of a task is to perform the specific task action (eg. TcpSyn scan or filter over some wordlist), update the task_data or job_data, /// and update the task_id state in the database. /// /// When building a task that can be distributed over multiple compute nodes in parallel we'll be using three separate tasks to coordinate the fan-out/fan-in process: /// - A generator task: Responsible for generating the actual tasks to perform. /// Will take in all parameters (e.g. wordlists, targets, max_concurrency, etc) and generate the appropriate tasks /// - A 'basic/normal' task: This represents the actual code that we want to run in a single distributed task. It will be responsible /// for performing the task, getting the results and storing them in the task_data table as JSONB blobs. /// - An aggregator task: Responsible for aggregating the results of multiple normal tasks. It will combine the results and /// store them in the job_data table as a single JSONB blob. Note that the job_data table is responsible for the full job /// datasets that can be used in subsequent tasks, whilst the task_data is only used for temporary storage of intermediate tasks. /// /// The reason why we want to leverage this three tiered strategy for generating tasks is to ensure our job manager can be run asynchronously /// and it won't end up in race conditions. TODO: How can we really achieve this? I guess we'd still need some kind of locking mechanism on the database /// when updating task's status field. /// /// # Fan-Out-Fan-In Task Generation Flow /// /// This section explains how the fan-out-fan-in pattern works for task generation and execution, /// and how the task queue evolves over time based on task state. /// /// ## Example Flow /// /// 1. **Generator Task Enqueued** /// - The job manager enqueues a `Generator` task (e.g., `TcpSynGenerator`). /// - The task queue: `[TcpSynGenerator]` /// /// 2. **Generator Task Runs** /// - The generator task runs, producing multiple "normal" tasks (e.g., `TcpSyn` for each target/port). /// - These tasks are added to the queue. /// - The task queue: `[TcpSyn(target1), TcpSyn(target2), ..., TcpSyn(targetN)]` /// /// 3. **Normal Tasks Execute in Parallel** /// - Each `TcpSyn` task runs independently (fan-out). /// - As tasks complete, their state is updated (`Succeeded`, `Failed`, etc.). /// - The queue is updated as tasks finish. /// /// 4. 
**Aggregator Task Enqueued** /// - Once all normal tasks reach a terminal state (e.g., all `TcpSyn` tasks are `Succeeded` or `Failed`), /// the job manager enqueues an `Aggregator` task (e.g., `TcpSynAggregator`). /// - The task queue: `[TcpSynAggregator]` /// /// 5. **Aggregator Task Runs** /// - The aggregator task collects results from all normal tasks, combines them, and stores the final output. /// - The queue is empty or moves to the next stage if more tasks are defined. /// /// ## Task State Dependency /// /// - The job manager only proceeds to the next stage (e.g., enqueuing the aggregator) when all tasks /// of the current stage are in a terminal state (`Succeeded`, `Failed`, `PartiallyFailed`, etc.). /// - This ensures correct fan-in behavior and prevents race conditions. /// /// ## Example Timeline /// /// | Time | Task Queue State | Action Taken | /// |------|--------------------------------------------------|-------------------------------------| /// | t0 | [TcpSynGenerator] | Generator task enqueued | /// | t1 | [TcpSyn(target1), ..., TcpSyn(targetN)] | Generator produced N tasks | /// | t2 | [TcpSyn(target1: Running), ..., TcpSyn(targetN)] | Normal tasks running in parallel | /// | t3 | [TcpSyn(target1: Succeeded), ..., TcpSyn(targetN: Succeeded)] | All normal tasks done | /// | t4 | [TcpSynAggregator] | Aggregator enqueued | /// | t5 | [] | Aggregator finished, job complete | /// /// This pattern allows for scalable, parallel execution with clear synchronization points. #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "kind", rename_all = "snake_case")] #[expect(clippy::large_enum_variant)] // TODO: Consider boxing large variants if performance becomes an issue pub enum TaskKind { /// Generate TcpSyn tasks TcpSynGenerator(TcpSynGenerator), /// Perform a TcpSyn scan TcpSyn(TcpSyn), /// Aggregate TcpSyn scan results TcpSynAggregator(TcpSynAggregator), /// Generate TlsScan tasks TlsScanGenerator(TlsScanGenerator), /// Perform a TlsScan TlsScan(TlsScan), /// Generate DNS Resolver tasks DnsScanGenerator(DnsScanGenerator), /// Perform a DNS Resolver task DnsScan(DnsScan), /// Aggregate DNS Resolver results DnsScanAggregator(DnsScanAggregator), /// Aggregate TlsScan results TlsScanAggregator(TlsScanAggregator), // TODO: Tmp commented out as i need to get this compatible with Clone // Filter on words in str // WordFilter(Filter), /// Filter on tertiary values True, False or None TertiaryFilter(Filter>, Option>), /// Filter on numeric values NumericFilter(Filter, i64>), /// Conditional task that evaluates a condition and decides the next steps Conditional(Conditional), /// Insert static data into the task collection's data map InsertData(InsertData), /// Get details for one or more Recon API models ListModels(ListModels), /// Upsert one or more Recon API models UpsertModels(UpsertModels), /// Send an error notification ErrorNotification(ErrorNotification), /// Send a generic notification GenericNotification(GenericNotification), /// Send a notification based on task results TaskResultNotification(TaskCollectionResultNotification), } #[cfg(feature = "sqlx")] impl FromRow<'_, PgRow> for TaskKind { // I've only built this only for the definition column in Task struct, not for an entry from the task_template table. // I'll need to handle those differently. 
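// Hedged example of the `definition` JSONB value this decoder expects: the
// enum is internally tagged with `kind` (snake_case), and the remaining fields
// belong to the concrete definition type (field names other than `kind` are
// illustrative):
//
//     {
//       "kind": "tcp_syn_generator",
//       "root_container_id": "...",
//       "ipv4_address_ranges": [...],
//       "port_ranges": [...]
//     }
//
// `from_row` pulls the `kind` discriminator out first and then hands the full
// value to the matching `from_definition(id, definition)` constructor.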
fn from_row(row: &PgRow) -> sqlx::Result { let id: TaskId = row.try_get("id")?; let definition: serde_json::Value = row.try_get("definition")?; let kind: &str = definition .get("kind") .and_then(|k| k.as_str()) .ok_or_else(|| sqlx::Error::ColumnDecode { index: "kind".into(), source: Box::new(Error::Other( "Missing or invalid 'kind' field in definition".into(), )), })?; match kind { "tcp_syn_generator" => Ok(TaskKind::TcpSynGenerator(TcpSynGenerator::from_definition( id, definition, )?)), "tcp_syn" => Ok(TaskKind::TcpSyn(TcpSyn::from_definition(id, definition)?)), "tcp_syn_aggregator" => Ok(TaskKind::TcpSynAggregator( TcpSynAggregator::from_definition(id, definition)?, )), "tls_scan_generator" => Ok(TaskKind::TlsScanGenerator( TlsScanGenerator::from_definition(id, definition)?, )), "tls_scan" => Ok(TaskKind::TlsScan(TlsScan::from_definition(id, definition)?)), "tls_scan_aggregator" => Ok(TaskKind::TlsScanAggregator( TlsScanAggregator::from_definition(id, definition)?, )), "dns_scan_generator" => Ok(TaskKind::DnsScanGenerator( DnsScanGenerator::from_definition(id, definition)?, )), "dns_scan" => Ok(TaskKind::DnsScan(DnsScan::from_definition(id, definition)?)), "dns_scan_aggregator" => Ok(TaskKind::DnsScanAggregator( DnsScanAggregator::from_definition(id, definition)?, )), // "word_filter" => Ok(TaskKind::WordFilter( // Filter::::from_definition(id, definition)?, // )), "tertiary_filter" => Ok(TaskKind::TertiaryFilter(Filter::< GenericMatcher>, Option, >::from_definition( id, definition )?)), "numeric_filter" => Ok(TaskKind::NumericFilter( Filter::, i64>::from_definition(id, definition)?, )), "conditional" => Ok(TaskKind::Conditional(Conditional::from_definition( id, definition, )?)), "insert_data" => Ok(TaskKind::InsertData(InsertData::from_definition( id, definition, )?)), "list_models" => Ok(TaskKind::ListModels(ListModels::from_definition( id, definition, )?)), "upsert_models" => Ok(TaskKind::UpsertModels(UpsertModels::from_definition( id, definition, )?)), "error_notification" => Ok(TaskKind::ErrorNotification( ErrorNotification::from_definition(id, definition)?, )), "generic_notification" => Ok(TaskKind::GenericNotification( GenericNotification::from_definition(id, definition)?, )), "task_result_notification" => Ok(TaskKind::TaskResultNotification( TaskCollectionResultNotification::from_definition(id, definition)?, )), _ => Err(sqlx::Error::ColumnDecode { index: "kind".into(), source: Box::new(Error::Other(format!( "Unknown task type: {}. 
Update the from_row function on TaskKind", kind ))), }), } } } impl fmt::Display for TaskKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let kind_str = match self { TaskKind::TcpSynGenerator(_) => "tcp_syn_generator", TaskKind::TcpSyn(_) => "tcp_syn", TaskKind::TcpSynAggregator(_) => "tcp_syn_aggregator", TaskKind::TlsScanGenerator(_) => "tls_scan_generator", TaskKind::TlsScan(_) => "tls_scan", TaskKind::TlsScanAggregator(_) => "tls_scan_aggregator", TaskKind::DnsScanGenerator(_) => "dns_scan_generator", TaskKind::DnsScan(_) => "dns_scan", TaskKind::DnsScanAggregator(_) => "dns_scan_aggregator", // TaskKind::WordFilter(_) => "word_filter", TaskKind::TertiaryFilter(_) => "tertiary_filter", TaskKind::NumericFilter(_) => "numeric_filter", TaskKind::Conditional(_) => "conditional", TaskKind::InsertData(_) => "insert_data", TaskKind::ListModels(_) => "list_models", TaskKind::UpsertModels(_) => "upsert_models", TaskKind::ErrorNotification(_) => "error_notification", TaskKind::GenericNotification(_) => "generic_notification", TaskKind::TaskResultNotification(_) => "task_result_notification", }; write!(f, "{}", kind_str) } } // TaskKind implementation // ======================= /// Dynamically generate methods that call the underlying task implementations getters. /// /// This reduces the boiler plate code for each getter, and still avoid having to resort to dynamic dispatch. macro_rules! delegate_task_method { ($fn_name:ident, $ret:ty) => { pub fn $fn_name(&self) -> $ret { match self { TaskKind::TcpSynGenerator(task) => task.$fn_name(), TaskKind::TcpSyn(task) => task.$fn_name(), TaskKind::TcpSynAggregator(task) => task.$fn_name(), TaskKind::TlsScanGenerator(task) => task.$fn_name(), TaskKind::TlsScan(task) => task.$fn_name(), TaskKind::TlsScanAggregator(task) => task.$fn_name(), TaskKind::DnsScanGenerator(task) => task.$fn_name(), TaskKind::DnsScan(task) => task.$fn_name(), TaskKind::DnsScanAggregator(task) => task.$fn_name(), TaskKind::TertiaryFilter(task) => task.$fn_name(), // TaskKind::WordFilter(task) => task.$fn_name(), TaskKind::NumericFilter(task) => task.$fn_name(), TaskKind::Conditional(task) => task.$fn_name(), TaskKind::InsertData(task) => task.$fn_name(), TaskKind::ListModels(task) => task.$fn_name(), TaskKind::UpsertModels(task) => task.$fn_name(), TaskKind::ErrorNotification(task) => task.$fn_name(), TaskKind::GenericNotification(task) => task.$fn_name(), TaskKind::TaskResultNotification(task) => task.$fn_name(), } } }; } impl TaskKind { delegate_task_method!(requires_internal_worker, bool); delegate_task_method!(requires_raw_sockets, bool); delegate_task_method!(worker_requirements, &'static [WorkerRequirements]); delegate_task_method!(timeout_secs, u32); #[cfg(feature = "task_execution")] pub async fn execute(&mut self, context: &ExecutionContext) -> Result { match self { TaskKind::TcpSynGenerator(task) => task.execute(context).await, TaskKind::TcpSyn(task) => task.execute(context).await, TaskKind::TcpSynAggregator(task) => task.execute(context).await, TaskKind::TlsScanGenerator(task) => task.execute(context).await, TaskKind::TlsScan(task) => task.execute(context).await, TaskKind::TlsScanAggregator(task) => task.execute(context).await, TaskKind::DnsScanGenerator(task) => task.execute(context).await, TaskKind::DnsScan(task) => task.execute(context).await, TaskKind::DnsScanAggregator(task) => task.execute(context).await, TaskKind::TertiaryFilter(task) => task.execute(context).await, // TaskKind::WordFilter(task) => task.execute().await, 
TaskKind::NumericFilter(task) => task.execute(context).await, TaskKind::Conditional(task) => task.execute(context).await, TaskKind::InsertData(task) => task.execute(context).await, TaskKind::ListModels(task) => task.execute(context).await, TaskKind::UpsertModels(task) => task.execute(context).await, TaskKind::ErrorNotification(task) => task.execute(context).await, TaskKind::GenericNotification(task) => task.execute(context).await, TaskKind::TaskResultNotification(task) => task.execute(context).await, } } } ``` ### `shared/mrpf_core/src/tasks/models/task_result.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task_result.rs` ```rust use super::TaskState; use serde::{Deserialize, Serialize}; use crate::templates::DataMap; /// The result of a task execution. /// /// TODO: This should likely be an enum so we can enforce certain states /// and their associated data. For example, a Generator task must have an /// expected_task_count. #[derive(Serialize, Deserialize, Debug, Clone)] pub enum TaskResult { /// A task was completed successfully without any data. Completed(TaskState), /// A task was completed successfully with data. CompletedWithData(TaskState, DataMap), /// A task was generated with a specific number of subtasks. GeneratedTasks(i16), } ``` ### `shared/mrpf_core/src/tasks/models/task_state.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/task_state.rs` ```rust use crate::Error; use serde::{Deserialize, Serialize}; use std::fmt; use std::str::FromStr; /// The state of a task /// /// The job manager will use the state to determine when to progress to the next /// task or finish the job. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)] #[sqlx(type_name = "task_state", rename_all = "snake_case")] #[serde(rename_all = "snake_case")] pub enum TaskState { /// The task is not scheduled yet Pending, /// The task is currently running Running, /// The task has completed generating create task events, and is waiting for all /// child tasks to be created by the task manager. PendingTaskCreation, /// The task has partially failed, meaning that some of the sub tasks have failed. 
/// This state can only be applied to TaskContainers PartiallyFailed, /// The task has succeeded Succeeded, /// The task has failed Failed, /// The task is disabled Disabled, /// The task has timed out Timeout, } impl TaskState { /// Returns true if the task is in a final state pub fn is_completed(&self) -> bool { matches!( self, TaskState::Succeeded | TaskState::Failed | TaskState::Disabled | TaskState::Timeout | TaskState::PartiallyFailed ) } } impl fmt::Display for TaskState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { TaskState::Pending => write!(f, "pending"), TaskState::Running => write!(f, "running"), TaskState::PendingTaskCreation => write!(f, "pending_task_creation"), TaskState::PartiallyFailed => write!(f, "partially_failed"), TaskState::Succeeded => write!(f, "succeeded"), TaskState::Failed => write!(f, "failed"), TaskState::Disabled => write!(f, "disabled"), TaskState::Timeout => write!(f, "timeout"), } } } impl FromStr for TaskState { type Err = Error; fn from_str(input: &str) -> Result { match input { "pending" => Ok(TaskState::Pending), "running" => Ok(TaskState::Running), "pending_task_creation" => Ok(TaskState::PendingTaskCreation), "partially_failed" => Ok(TaskState::PartiallyFailed), "succeeded" => Ok(TaskState::Succeeded), "failed" => Ok(TaskState::Failed), "disabled" => Ok(TaskState::Disabled), "timeout" => Ok(TaskState::Timeout), _ => Err("Invalid task state".into()), } } } ``` ### `shared/mrpf_core/src/tasks/models/worker_requirements.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/worker_requirements.rs` ```rust /// The required worker properties for a Task to be able to execute. /// /// A task can have multiple requirements, evaluated in an AND condition. #[derive(Debug, PartialEq)] pub enum WorkerRequirements { /// Direct database access to the data table. /// /// Useful for tasks that need to do transformations of our results, for example our Filter tasks. TaskDataAccess, /// Access to the task manager queue. /// /// Useful for generator tasks that need to create new tasks and send them to the queue. QueueAccess, /// Access to the MRPF recon database tables. /// /// Useful for tasks that need to read or write to the MRPF recon database, for example GetTarget tasks. ReconDataAccess, /// Raw socket access. /// /// Useful for tasks that need to send or receive raw packets, for example TCP SYN scan tasks. RawSocketAccess, } ``` ### `shared/mrpf_core/src/tasks/queue_client.rs` **Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/queue_client.rs` ```rust use crate::Result; use crate::tasks::models::{TaskManagerQueueMessage, WorkerQueueMessage}; use std::env; use tracing::{debug, error}; /// Retrieve a string variable from the environment or use a default. 
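// Hedged configuration example: the queue URLs are read from the environment
// variables used in `QueueClient::new` below, falling back to placeholder
// defaults when unset. The URLs and account id here are fabricated for
// illustration only.
//
//     TASK_MANAGER_QUEUE=https://sqs.eu-west-1.amazonaws.com/111111111111/mrpf-task-manager
//     EXTERNAL_WORKER_QUEUE=https://sqs.eu-west-1.amazonaws.com/111111111111/mrpf-ext-worker
//     INTERNAL_WORKER_QUEUE=https://sqs.eu-west-1.amazonaws.com/111111111111/mrpf-int-worker
//     BARE_METAL_WORKER_QUEUE=https://sqs.eu-west-1.amazonaws.com/111111111111/mrpf-bare-metal-worker
//
//     // Construction follows the same pattern used elsewhere in this crate:
//     let config = aws_config::load_from_env().await;
//     let queue_client = QueueClient::new(&config);
//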
### `shared/mrpf_core/src/tasks/models/worker_requirements.rs`

**Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/models/worker_requirements.rs`

```rust
/// The required worker properties for a Task to be able to execute.
///
/// A task can have multiple requirements, evaluated in an AND condition.
#[derive(Debug, PartialEq)]
pub enum WorkerRequirements {
    /// Direct database access to the data table.
    ///
    /// Useful for tasks that need to do transformations of our results, for example our Filter tasks.
    TaskDataAccess,
    /// Access to the task manager queue.
    ///
    /// Useful for generator tasks that need to create new tasks and send them to the queue.
    QueueAccess,
    /// Access to the MRPF recon database tables.
    ///
    /// Useful for tasks that need to read or write to the MRPF recon database, for example GetTarget tasks.
    ReconDataAccess,
    /// Raw socket access.
    ///
    /// Useful for tasks that need to send or receive raw packets, for example TCP SYN scan tasks.
    RawSocketAccess,
}
```
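The AND semantics from the doc comment could be evaluated roughly as below. The `WorkerCapabilities` struct and `can_execute` helper are hypothetical and not part of `mrpf_core`; they only illustrate matching requirements against a worker's capabilities.

```rust
use mrpf_core::tasks::models::WorkerRequirements; // assumed re-export path

/// Hypothetical capability flags for a worker; not part of mrpf_core.
struct WorkerCapabilities {
    task_data_access: bool,
    queue_access: bool,
    recon_data_access: bool,
    raw_socket_access: bool,
}

/// A worker may run a task only if it satisfies every requirement (AND condition).
fn can_execute(caps: &WorkerCapabilities, requirements: &[WorkerRequirements]) -> bool {
    requirements.iter().all(|req| match req {
        WorkerRequirements::TaskDataAccess => caps.task_data_access,
        WorkerRequirements::QueueAccess => caps.queue_access,
        WorkerRequirements::ReconDataAccess => caps.recon_data_access,
        WorkerRequirements::RawSocketAccess => caps.raw_socket_access,
    })
}
```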
### `shared/mrpf_core/src/tasks/queue_client.rs`

**Absolute path:** `/root/mrpf/shared/mrpf_core/src/tasks/queue_client.rs`

```rust
use crate::Result;
use crate::tasks::models::{TaskManagerQueueMessage, WorkerQueueMessage};
use std::env;
use tracing::{debug, error};

/// Retrieve a string variable from the environment or use a default.
fn get_env_or_default(var: &str, default: &str) -> String {
    env::var(var).unwrap_or_else(|_| default.to_string())
}

pub struct QueueClient {
    /// The SQS Client
    client: aws_sdk_sqs::Client,
    /// The Task Manager SQS queue URL
    mgr_queue: String,
    /// The External Worker SQS queue URL
    ext_worker_queue: String,
    /// The Internal Worker SQS queue URL
    int_worker_queue: String,
    /// The Bare Metal Worker SQS queue URL
    bare_metal_worker_queue: String,
}

impl QueueClient {
    pub fn new(config: &aws_config::SdkConfig) -> Self {
        Self {
            client: aws_sdk_sqs::Client::new(config),
            mgr_queue: get_env_or_default("TASK_MANAGER_QUEUE", "sqs://Task Manager Queue"),
            ext_worker_queue: get_env_or_default(
                "EXTERNAL_WORKER_QUEUE",
                "sqs://External Worker Queue",
            ),
            int_worker_queue: get_env_or_default(
                "INTERNAL_WORKER_QUEUE",
                "sqs://Internal Worker Queue",
            ),
            bare_metal_worker_queue: get_env_or_default(
                "BARE_METAL_WORKER_QUEUE",
                "sqs://Bare Metal Worker Queue",
            ),
        }
    }

    /// Send a single TaskManagerQueueMessage to the manager SQS queue.
    pub async fn send_mgr_message(&self, message: &TaskManagerQueueMessage) -> Result<()> {
        let json = serde_json::to_value(message)?;
        self.send_message(&json, &self.mgr_queue).await
    }

    /// Send multiple TaskManagerQueueMessages to the manager SQS queue concurrently.
    pub async fn send_mgr_messages(&self, messages: &[TaskManagerQueueMessage]) -> Result<()> {
        let mut serialized = Vec::with_capacity(messages.len());
        for msg in messages {
            serialized.push(serde_json::to_value(msg)?);
        }
        self.send_messages(&serialized, &self.mgr_queue).await
    }

    /// Send a single WorkerQueueMessage to the internal worker SQS queue.
    pub async fn send_int_worker_message(&self, message: &WorkerQueueMessage) -> Result<()> {
        let json = serde_json::to_value(message)?;
        self.send_message(&json, &self.int_worker_queue).await
    }

    /// Send multiple WorkerQueueMessages to the internal worker SQS queue concurrently.
    pub async fn send_int_worker_messages(&self, messages: &[WorkerQueueMessage]) -> Result<()> {
        let mut serialized = Vec::with_capacity(messages.len());
        for msg in messages {
            serialized.push(serde_json::to_value(msg)?);
        }
        self.send_messages(&serialized, &self.int_worker_queue)
            .await
    }

    /// Send a single WorkerQueueMessage to the external worker SQS queue.
    pub async fn send_ext_worker_message(&self, message: &WorkerQueueMessage) -> Result<()> {
        let json = serde_json::to_value(message)?;
        self.send_message(&json, &self.ext_worker_queue).await
    }

    /// Send multiple WorkerQueueMessages to the external worker SQS queue concurrently.
    pub async fn send_ext_worker_messages(&self, messages: &[WorkerQueueMessage]) -> Result<()> {
        let mut serialized = Vec::with_capacity(messages.len());
        for msg in messages {
            serialized.push(serde_json::to_value(msg)?);
        }
        self.send_messages(&serialized, &self.ext_worker_queue)
            .await
    }

    /// Send a single WorkerQueueMessage to the bare metal worker SQS queue.
    pub async fn send_bare_metal_worker_message(&self, message: &WorkerQueueMessage) -> Result<()> {
        let json = serde_json::to_value(message)?;
        self.send_message(&json, &self.bare_metal_worker_queue)
            .await
    }

    /// Send multiple WorkerQueueMessages to the bare metal worker SQS queue concurrently.
    pub async fn send_bare_metal_worker_messages(&self, messages: &[WorkerQueueMessage]) -> Result<()> {
        let mut serialized = Vec::with_capacity(messages.len());
        for msg in messages {
            serialized.push(serde_json::to_value(msg)?);
        }
        self.send_messages(&serialized, &self.bare_metal_worker_queue)
            .await
    }

    /// Receive a single message from the bare metal worker SQS queue.
    ///
    /// The internal and external worker queues will use the normal AWS Lambda triggers for SQS. The
    /// bare metal worker queue is polled manually by the bare metal worker service.
    pub async fn receive_bare_metal_worker_message(&self) -> Result<Option<WorkerQueueMessage>> {
        let message = self.receive_message(&self.bare_metal_worker_queue).await?;
        match message {
            Some(json) => {
                debug!("Received bare metal worker message: {:?}", json);
                Ok(serde_json::from_value(json)?)
            }
            None => Ok(None),
        }
    }

    /// Receive a single message from the specified SQS queue.
    ///
    /// If successfully retrieved, delete the message from the queue. NOTE: This means that if processing fails,
    /// the message is lost, and our code cannot requeue it. We may want to revisit this decision later.
    async fn receive_message(&self, queue_url: &str) -> Result<Option<serde_json::Value>> {
        let response = self
            .client
            .receive_message()
            .queue_url(queue_url)
            .max_number_of_messages(1)
            .wait_time_seconds(10)
            .send()
            .await?;

        if let Some(messages) = response.messages
            && let Some(message) = messages.into_iter().next()
            && let Some(body) = message.body
        {
            debug!("Received SQS message body: {}", body);
            // Parse JSON first. If parsing fails, return the error and leave the message
            // on the queue so it can be retried or put on the dead-letter queue.
            let json: serde_json::Value = match serde_json::from_str(&body) {
                Ok(j) => j,
                Err(e) => {
                    error!(error = %e, "Failed to parse SQS message body as JSON");
                    return Err(e.into());
                }
            };

            // Delete the message after successful parsing.
            if let Some(receipt_handle) = message.receipt_handle {
                if let Err(e) = self
                    .client
                    .delete_message()
                    .queue_url(queue_url)
                    .receipt_handle(receipt_handle)
                    .send()
                    .await
                {
                    error!(error = %e, "Failed to delete SQS message after processing");
                }
            } else {
                error!("SQS message missing receipt handle; cannot delete");
            }

            return Ok(Some(json));
        }
        Ok(None)
    }

    /// Send a single message to the specified SQS queue.
    ///
    /// # Arguments
    /// * `message` - The message to send, as a serde_json::Value.
    /// * `queue` - The URL of the SQS queue to send the message to
    async fn send_message(&self, message: &serde_json::Value, queue: &str) -> Result<()> {
        // Serialize message to string (avoid heap if possible, but SQS requires String)
        let msg_body = serde_json::to_string(message)?;

        // Send message to SQS
        let resp = self
            .client
            .send_message()
            .queue_url(queue)
            .message_body(msg_body)
            .send()
            .await;

        // Minimal error chain: map AWS error to our Result
        match resp {
            Ok(_) => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// Send multiple messages to the specified SQS queue concurrently (private, generic).
    async fn send_messages(&self, messages: &[serde_json::Value], queue: &str) -> Result<()> {
        let futs = messages
            .iter()
            .map(|json| self.send_message(json, queue))
            .collect::<Vec<_>>();

        // Send all messages concurrently; surface the first error after all futures complete.
        let results = futures::future::join_all(futs).await;
        for result in results {
            result?;
        }
        Ok(())
    }
}
```
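A usage sketch for `QueueClient`: build it from the ambient AWS configuration and push a message onto the manager queue. The import paths are assumptions, the message is taken as a parameter because its construction is not shown here, and `aws_config::load_defaults(BehaviorVersion::latest())` assumes an aws-config 1.x-style API; the exact entry point depends on the SDK version pinned in the workspace.

```rust
use mrpf_core::Result; // assumed crate-level Result alias
use mrpf_core::tasks::models::TaskManagerQueueMessage; // assumed re-export path
use mrpf_core::tasks::queue_client::QueueClient;       // assumed module path

async fn notify_manager(message: &TaskManagerQueueMessage) -> Result<()> {
    // Load the default AWS configuration (region, credentials) from the environment.
    let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
    let queue_client = QueueClient::new(&config);

    // Queue URLs come from TASK_MANAGER_QUEUE / *_WORKER_QUEUE environment variables,
    // falling back to the placeholder defaults baked into QueueClient::new.
    queue_client.send_mgr_message(message).await
}
```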
### `shared/mrpf_core/src/templates/data_kind.rs`

**Absolute path:** `/root/mrpf/shared/mrpf_core/src/templates/data_kind.rs`

```rust
use crate::Error;
use crate::recon::{Domain, Ipv4Address, Target, Word};
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgRow;
use sqlx::{FromRow, Row};
use std::collections::{HashMap, HashSet};

/// A type alias for a collection of data kinds.
///
/// This allows us to unpack and manage different types of data for a single task within
/// the database client.
pub type DataMap = HashMap<String, DataKind>;

/// Represents the kind of data stored in our task manager.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "kind", content = "data", rename_all = "snake_case")]
#[expect(clippy::large_enum_variant)] // TODO: Consider boxing large variants if performance becomes an issue
pub enum DataKind {
    SetOfStrings(HashSet<String>),
    SetOfIntegers(HashSet<i64>),
    SetOfFqdns(HashSet<String>),
    SetOfPorts(HashSet<u16>),
    MapOfDnsRecords(HashMap<String, HashMap<String, Vec<String>>>),
    Target(Target),
    SetOfTargets(HashSet<Target>),
    Ipv4Address(Ipv4Address),
    SetOfIpv4Addresses(HashSet<Ipv4Address>),
    Domain(Domain),
    SetOfDomains(HashSet<Domain>),
    Word(Word),
    SetOfWords(HashSet<Word>),
}

impl std::fmt::Display for DataKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let kind_str = match self {
            DataKind::SetOfStrings(_) => "set_of_strings",
            DataKind::SetOfIntegers(_) => "set_of_integers",
            DataKind::SetOfFqdns(_) => "set_of_fqdns",
            DataKind::SetOfPorts(_) => "set_of_ports",
            DataKind::MapOfDnsRecords(_) => "map_of_dns_records",
            DataKind::Target(_) => "target",
            DataKind::SetOfTargets(_) => "set_of_targets",
            DataKind::Ipv4Address(_) => "ipv4_address",
            DataKind::SetOfIpv4Addresses(_) => "set_of_ipv4_addresses",
            DataKind::Domain(_) => "domain",
            DataKind::SetOfDomains(_) => "set_of_domains",
            DataKind::Word(_) => "word",
            DataKind::SetOfWords(_) => "set_of_words",
        };
        write!(f, "{}", kind_str)
    }
}

impl DataKind {
    /// Get the number of items in the DataKind.
    pub fn len(&self) -> usize {
        match self {
            DataKind::SetOfStrings(set) => set.len(),
            DataKind::SetOfIntegers(set) => set.len(),
            DataKind::SetOfFqdns(set) => set.len(),
            DataKind::SetOfPorts(set) => set.len(),
            DataKind::MapOfDnsRecords(map) => map.len(),
            DataKind::Target(_) => 1,
            DataKind::SetOfTargets(set) => set.len(),
            DataKind::Ipv4Address(_) => 1,
            DataKind::SetOfIpv4Addresses(set) => set.len(),
            DataKind::Domain(_) => 1,
            DataKind::SetOfDomains(set) => set.len(),
            DataKind::Word(_) => 1,
            DataKind::SetOfWords(set) => set.len(),
        }
    }

    /// Check if the DataKind is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(feature = "sqlx")]
impl FromRow<'_, PgRow> for DataKind {
    fn from_row(row: &PgRow) -> sqlx::Result<Self> {
        let value: serde_json::Value = row.try_get("value")?;
        // If this fails, we likely have forgotten to cast "kind" as text in our SQL query.
        // Some error like this:
        // error occurred while decoding column "kind": mismatched types; Rust type
        // `alloc::string::String` (as SQL type `TEXT`) is not compatible with SQL type `data_kind`
        let kind: String = row.try_get("kind")?;
        match kind.as_str() {
            "set_of_strings" => value
                .as_array()
                .map(|arr| {
                    DataKind::SetOfStrings(
                        arr.iter()
                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
                            .collect(),
                    )
                })
                .ok_or_else(|| sqlx::Error::ColumnDecode {
                    index: "value".into(),
                    source: Box::new(Error::Other("Failed to parse set_of_strings".into())),
                }),
            "set_of_integers" => value
                .as_array()
                .map(|arr| DataKind::SetOfIntegers(arr.iter().filter_map(|v| v.as_i64()).collect()))
                .ok_or_else(|| sqlx::Error::ColumnDecode {
                    index: "value".into(),
                    source: Box::new(Error::Other("Failed to parse set_of_integers".into())),
                }),
            "set_of_fqdns" => value
                .as_array()
                .map(|arr| {
                    DataKind::SetOfFqdns(
                        arr.iter()
                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
                            .collect(),
                    )
                })
                .ok_or_else(|| sqlx::Error::ColumnDecode {
                    index: "value".into(),
                    source: Box::new(Error::Other("Failed to parse set_of_fqdns".into())),
                }),
            "set_of_ports" => value
                .as_array()
                .map(|arr| {
                    DataKind::SetOfPorts(
                        arr.iter()
                            .filter_map(|v| v.as_u64().and_then(|n| u16::try_from(n).ok()))
                            .collect(),
                    )
                })
                .ok_or_else(|| sqlx::Error::ColumnDecode {
                    index: "value".into(),
                    source: Box::new(Error::Other("Failed to parse set_of_ports".into())),
                }),
            "map_of_dns_records" => {
                let map: HashMap<String, HashMap<String, Vec<String>>> =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse map_of_dns_records".into())),
                    })?;
                Ok(DataKind::MapOfDnsRecords(map))
            }
            "target" => {
                let target: Target =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse target".into())),
                    })?;
                Ok(DataKind::Target(target))
            }
            "set_of_targets" => {
                let targets: HashSet<Target> =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse set_of_targets".into())),
                    })?;
                Ok(DataKind::SetOfTargets(targets))
            }
            "ipv4_address" => {
                let ipv4_address: Ipv4Address =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse ipv4_address".into())),
                    })?;
                Ok(DataKind::Ipv4Address(ipv4_address))
            }
            "set_of_ipv4_addresses" => {
                let ipv4_addresses: HashSet<Ipv4Address> =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other(
                            "Failed to parse set_of_ipv4_addresses".into(),
                        )),
                    })?;
                Ok(DataKind::SetOfIpv4Addresses(ipv4_addresses))
            }
            "domain" => {
                let domain: Domain =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse domain".into())),
                    })?;
                Ok(DataKind::Domain(domain))
            }
            "set_of_domains" => {
                let domains: HashSet<Domain> =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse set_of_domains".into())),
                    })?;
                Ok(DataKind::SetOfDomains(domains))
            }
            "word" => {
                let word: Word =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse word".into())),
                    })?;
                Ok(DataKind::Word(word))
            }
            "set_of_words" => {
                let words: HashSet<Word> =
                    serde_json::from_value(value).map_err(|_e| sqlx::Error::ColumnDecode {
                        index: "value".into(),
                        source: Box::new(Error::Other("Failed to parse set_of_words".into())),
                    })?;
                Ok(DataKind::SetOfWords(words))
            }
            _ => Err(sqlx::Error::ColumnDecode {
                index: "kind".into(),
                source: Box::new(Error::Other(format!("Unknown kind: {kind}"))),
            }),
        }
    }
}
```

### `shared/mrpf_core/src/templates/mod.rs`

**Absolute path:** `/root/mrpf/shared/mrpf_core/src/templates/mod.rs`

```rust
mod data_kind;
mod parser;
mod template;

pub use data_kind::{DataKind, DataMap};
use template::{RuntimeExpr, TemplatePart};
pub use template::{Template, Templates};
```

### `shared/mrpf_core/src/templates/parser.rs`

**Absolute path:** `/root/mrpf/shared/mrpf_core/src/templates/parser.rs`

```rust
use super::{RuntimeExpr, Template, TemplatePart};
use crate::{Error, Result};

/// Parse an entire template string into a sequence of template parts (literals and expressions).
///
/// This is the main entry point for parsing templates. It splits the input into literal and
/// expression segments, and parses each expression using the recursive descent parser.
pub fn parse_template(input: &str) -> Result
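The serde attributes on `DataKind` (`tag = "kind"`, `content = "data"`, snake_case) produce an adjacently tagged JSON object, which is also the shape the `FromRow` impl above expects to find split across the `kind` and `value` columns. A small illustrative sketch (import path assumed):

```rust
use mrpf_core::templates::DataKind; // assumed re-export path
use std::collections::HashSet;

fn example() -> serde_json::Result<()> {
    let ports: HashSet<u16> = [80, 443].into_iter().collect();
    let kind = DataKind::SetOfPorts(ports);

    // Adjacently tagged: {"kind":"set_of_ports","data":[443,80]} (set order is not guaranteed).
    let json = serde_json::to_string(&kind)?;
    assert!(json.contains("\"kind\":\"set_of_ports\""));

    // Display matches the serde tag, which is what the "kind" column carries in the database.
    assert_eq!(kind.to_string(), "set_of_ports");
    assert_eq!(kind.len(), 2);
    Ok(())
}
```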