|
use petgraph::algo::toposort; |
|
use petgraph::graph::DiGraph; |
|
use petgraph::graph::NodeIndex; |
|
use serde::Deserialize; |
|
use serde_json::Result; |
|
use std::collections::HashMap; |
|
use std::sync::{Arc, Mutex}; |
|
use tokio::task; |
|
|
|
/// A runnable node of the DAG: a named unit of work together with the data
/// names it declares as inputs and outputs.
///
/// Cloning is cheap for the function itself (it is reference-counted); the
/// name and the input/output lists are copied.
#[derive(Clone)]
struct Step {
    /// Step name, printed when the step is executed.
    name: String,
    /// Data names this step declares it consumes (copied from its config).
    inputs: Vec<String>,
    /// Data names this step declares it produces (copied from its config).
    outputs: Vec<String>,
    /// The work itself; it receives a handle to the shared key/value state.
    func: Arc<dyn Fn(Arc<Mutex<HashMap<String, String>>>) + Send + Sync>,
}

impl Step {
    /// Assembles a step from its parts; `name` is copied into an owned `String`.
    fn new(
        name: &str,
        inputs: Vec<String>,
        outputs: Vec<String>,
        func: Arc<dyn Fn(Arc<Mutex<HashMap<String, String>>>) + Send + Sync>,
    ) -> Self {
        let name = String::from(name);
        Self {
            name,
            inputs,
            outputs,
            func,
        }
    }

    /// Prints the step name, then runs the step's function against
    /// `shared_state`.
    ///
    /// The function is synchronous, so this future completes on its first poll.
    async fn execute(&self, shared_state: Arc<Mutex<HashMap<String, String>>>) {
        println!("Executing step: {}", self.name);
        let work = &*self.func;
        work(shared_state);
    }
}
|
|
|
// Struct representing a Step in the JSON input |
|
#[derive(Debug, Deserialize)] |
|
struct StepConfig { |
|
name: String, |
|
inputs: Vec<String>, |
|
outputs: Vec<String>, |
|
function: String, |
|
} |
|
|
|
// Struct for the entire configuration |
|
#[derive(Debug, Deserialize)] |
|
struct DagConfig { |
|
steps: Vec<StepConfig>, |
|
} |
|
|
|
// Function to parse the JSON input into DagConfig |
|
fn parse_dag_from_json(json_input: &str) -> Result<DagConfig> { |
|
serde_json::from_str(json_input) |
|
} |
|
|
|
/// Registry mapping the `function` names usable in the JSON config to their
/// implementations.
///
/// Each implementation receives the shared key/value state and reads/writes
/// it under the mutex.
///
/// # Panics
/// The step functions panic if the shared-state mutex is poisoned.
fn get_function_map(
) -> HashMap<String, Arc<dyn Fn(Arc<Mutex<HashMap<String, String>>>) + Send + Sync>> {
    type SharedState = Arc<Mutex<HashMap<String, String>>>;
    type StepFn = Arc<dyn Fn(SharedState) + Send + Sync>;

    // Writes `data1`.
    let step1: StepFn = Arc::new(|shared_state: SharedState| {
        println!("Step 1 executed, produced data1");
        let mut state = shared_state.lock().unwrap();
        state.insert("data1".to_string(), "value_from_step1".to_string());
    });

    // Reads `data1` (if present) and writes `data2`. Both happen under a single
    // lock so no other writer can interleave between the read and the write.
    let step2: StepFn = Arc::new(|shared_state: SharedState| {
        let mut state = shared_state.lock().unwrap();
        if let Some(data1) = state.get("data1") {
            println!("Step 2 executed, used data1: {}", data1);
        }
        state.insert("data2".to_string(), "value_from_step2".to_string());
    });

    // Reads `data2` (if present); writes nothing.
    let step3: StepFn = Arc::new(|shared_state: SharedState| {
        let state = shared_state.lock().unwrap();
        if let Some(data2) = state.get("data2") {
            println!("Step 3 executed, used data2: {}", data2);
        }
    });

    HashMap::from([
        ("step1_function".to_string(), step1),
        ("step2_function".to_string(), step2),
        ("step3_function".to_string(), step3),
    ])
}
|
|
|
// Function to build the DAG from the parsed configuration |
|
fn build_dag_from_config(config: &DagConfig) -> DiGraph<Arc<Step>, ()> { |
|
let mut dag = DiGraph::<Arc<Step>, ()>::new(); |
|
let mut step_nodes: HashMap<String, NodeIndex> = HashMap::new(); |
|
|
|
// Get the function map |
|
let func_map = get_function_map(); |
|
|
|
// Add steps to the graph |
|
for step_config in &config.steps { |
|
// Find the function by name |
|
let func = func_map |
|
.get(&step_config.function) |
|
.expect("Function not found") |
|
.clone(); |
|
|
|
// Create a new step |
|
let step = Step::new( |
|
&step_config.name, |
|
step_config.inputs.clone(), |
|
step_config.outputs.clone(), |
|
func, |
|
); |
|
|
|
// Add the step to the graph |
|
let node = dag.add_node(Arc::new(step)); |
|
step_nodes.insert(step_config.name.clone(), node); |
|
} |
|
|
|
// Create edges based on dependencies (inputs and outputs) |
|
for step_config in &config.steps { |
|
for input in &step_config.inputs { |
|
for output_step in &config.steps { |
|
if output_step.outputs.contains(input) { |
|
if let (Some(&from), Some(&to)) = ( |
|
step_nodes.get(&output_step.name), |
|
step_nodes.get(&step_config.name), |
|
) { |
|
dag.add_edge(from, to, ()); |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
dag |
|
} |
|
|
|
// Function to execute steps in topological order |
|
async fn execute_steps( |
|
dag: Arc<DiGraph<Arc<Step>, ()>>, |
|
shared_state: Arc<Mutex<HashMap<String, String>>>, |
|
) { |
|
let sorted_steps = toposort(&*dag, None).expect("Cyclic dependency detected"); |
|
|
|
// Spawn async tasks for each step |
|
for step_idx in sorted_steps { |
|
let step = dag[step_idx].clone(); // Clone Arc for task safety |
|
let shared_state = shared_state.clone(); |
|
task::spawn(async move { |
|
step.execute(shared_state).await; |
|
}) |
|
.await |
|
.unwrap(); |
|
} |
|
} |
|
|
|
#[tokio::main] |
|
async fn main() { |
|
// Example JSON input |
|
let json_input = r#" |
|
{ |
|
"steps": [ |
|
{ |
|
"name": "Step 1", |
|
"inputs": [], |
|
"outputs": ["data1"], |
|
"function": "step1_function" |
|
}, |
|
{ |
|
"name": "Step 2", |
|
"inputs": ["data1"], |
|
"outputs": ["data2"], |
|
"function": "step2_function" |
|
}, |
|
{ |
|
"name": "Step 3", |
|
"inputs": ["data2"], |
|
"outputs": ["result"], |
|
"function": "step3_function" |
|
} |
|
] |
|
} |
|
"#; |
|
|
|
// Parse the JSON input |
|
let config = parse_dag_from_json(json_input).expect("Invalid JSON"); |
|
|
|
// Build the DAG from the configuration (borrow the config) |
|
let dag = Arc::new(build_dag_from_config(&config)); |
|
|
|
// Shared state for passing data between steps |
|
let shared_state = Arc::new(Mutex::new(HashMap::new())); |
|
|
|
// Execute steps in topological order |
|
execute_steps(dag.clone(), shared_state).await; |
|
} |