[package]
name = "cete_node"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix-web = "4"
aws-config = "*"
aws-sdk-codecommit = "*"
aws-sdk-ecs = "*"
clap = { version = "4.5", features = ["derive"] }
clap_derive = "4.5.4"
env_logger = "0.9"
etcd-client = "0.12"
log = "0.4"
port_scanner = "*"
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["full"] }
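//! cete_node: a small actix-web service in which each task claims a shard
//! assignment from etcd at startup, exposes it over HTTP, and releases the
//! assignment again on shutdown.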
use actix_web::{
    middleware,
    web::{self},
    App, HttpResponse, HttpServer, Responder,
};
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_codecommit as codecommit;
use clap::Parser;
use etcd_client::{Client, Compare, CompareOp, Txn, TxnOp};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{
    collections::BTreeMap,
    env, fmt,
    net::TcpListener,
    sync::{Arc, Mutex},
};
use tokio::{
    signal::unix::{signal, SignalKind},
    sync::Mutex as AsyncMutex,
};

/// How many tasks are serving this shard?
#[derive(Serialize, Deserialize, Debug)]
struct ShardData {
    count: i32,
}

/// Which shard is this task assigned to serve?
#[derive(Serialize, Deserialize, Debug)]
struct TaskData {
    shard_id: usize,
}
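// etcd key layout used throughout (derived from the handlers below):
//   task/{port}             -> {"shard_id": usize}  (which shard this task serves)
//   shard_counts/shard_{n}  -> {"count": i32}       (how many tasks serve shard n)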
async fn is_task_up(port: u16) -> bool {
    let url = format!("http://localhost:{}/me", port);
    match reqwest::get(url).await {
        Ok(response) => response.status().as_u16() == 200,
        Err(e) => {
            // note: we expect this error so we don't log as error here
            log::debug!("{}", e);
            false
        }
    }
}

async fn show_tasks(data: web::Data<AppState>) -> impl Responder {
    log::info!("showing tasks across the entire cluster");
    let mut client = data.etcd.lock().await;
    // note: less confusing if output is in sorted order always
    let mut tasks = BTreeMap::new();
    let prefix = "task/";
    log::info!("searching prefix: {}", prefix);
    let response = client
        .get(prefix, Some(etcd_client::GetOptions::new().with_prefix()))
        .await
        .expect("got keys");
    log::info!("found some keys");
    for kv in response.kvs() {
        let key = String::from_utf8_lossy(kv.key());
        let port = key
            .split('/')
            .nth(1)
            .unwrap_or_default()
            .parse::<u16>()
            .expect("valid port");
        log::info!("checking key: {}", key);
        let value = serde_json::from_slice::<TaskData>(kv.value())
            .expect("valid task data")
            .shard_id;
        tasks.insert(key, (value, is_task_up(port).await));
    }
    HttpResponse::Ok().json(tasks)
}
async fn show_assignments(data: web::Data<AppState>) -> impl Responder {
    log::info!("showing assignments across the entire cluster");
    let mut client = data.etcd.lock().await;
    // note: less confusing if output is in sorted order always
    let mut shards = BTreeMap::new();
    for n in 0..data.total_shards {
        let prefix = format!("shard_counts/shard_{}", n);
        log::info!("searching prefix: {}", prefix);
        let response = client
            .get(
                prefix.clone(),
                Some(etcd_client::GetOptions::new().with_prefix()),
            )
            .await
            .expect("got keys");
        log::info!("found some keys");
        for kv in response.kvs() {
            let key = String::from_utf8_lossy(kv.key());
            log::info!("checking key: {}", key);
            let value = serde_json::from_slice::<ShardData>(kv.value())
                .expect("valid shard data")
                .count;
            shards.insert(prefix.clone(), value);
        }
    }
    HttpResponse::Ok().json(shards)
}
async fn show_assignment(data: web::Data<AppState>) -> impl Responder {
    let mut shard_assignment = data.shard_id.lock().expect("got shard_id");
    let mut shard_assignment_value = *shard_assignment;
    if shard_assignment_value.is_none() {
        log::warn!("initializing shard id since none was found");
        let new_shard = {
            determine_shard_assignment(
                &data.etcd,
                data.total_shards,
                data.replicas_per_shard,
                data.port,
            )
            .await
            .expect("can determine shard")
        };
        shard_assignment_value = Some(new_shard);
        *shard_assignment = shard_assignment_value;
    }
    HttpResponse::Ok().json(json!({
        "shard": shard_assignment_value.unwrap(),
        "task": data.port,
        "role": data.role,
    }))
}
async fn cleanup_assignment(etcd: &AsyncMutex<Client>, port: u16) {
    log::info!("I am task: {}. Shutting down.", port);
    let mut client = etcd.lock().await;
    let task_key = format!("task/{}", port);
    let response = client
        .get(task_key.clone(), None)
        .await
        .expect("got task info");
    let shard_id = response
        .kvs()
        .first()
        .map(|kv| {
            serde_json::from_slice::<TaskData>(kv.value())
                .unwrap()
                .shard_id
        })
        .unwrap_or(0);
    log::info!(
        "As task {}, I handled {}. This assignment will be released.",
        port,
        shard_id
    );
    let shard_key = format!("shard_counts/shard_{}", shard_id);
    let response = client
        .get(shard_key.clone(), None)
        .await
        .expect("got shard counts");
    let count = response
        .kvs()
        .first()
        .map(|kv| {
            serde_json::from_slice::<ShardData>(kv.value())
                .unwrap()
                .count
        })
        .expect("should have non-zero shards counted") as usize;
    log::info!(
        "As task {}, I handled {}, which will have count {} after release.",
        port,
        shard_id,
        count - 1
    );
    // note: unlike increment_shard_count, this decrement is a plain read-then-put,
    // so it assumes shutdowns of replicas of the same shard don't race each other
    let value = json!({ "count": count - 1 }).to_string();
    client
        .put(shard_key.clone(), value, None)
        .await
        .expect("decremented count");
    client.delete(task_key, None).await.expect("deleted task");
    log::info!("Task {} finished, shard released.", port);
}
// note: given how the placement strategy works, a recalculation will almost
// always hand you back the shard you already had
async fn recalculate_assignment(data: web::Data<AppState>) -> impl Responder {
    let response = match determine_shard_assignment(
        &data.etcd,
        data.total_shards,
        data.replicas_per_shard,
        data.port,
    )
    .await
    {
        Ok(shard_id) => {
            json!({"shard_id": shard_id})
        }
        Err(e) => {
            json!({"error": format!("{}", e)})
        }
    };
    HttpResponse::Ok().json(response)
}
// uses transactions to safely attempt to initialize or increment a shard count
// we don't retry in this function because if the cas failed we know another node took the assignment
// so we actually need to totally bail out here and continue to the next potentially available slot
async fn increment_shard_count(
    client: &mut Client,
    key: &str,
    potential_initial_value: Option<usize>,
) -> Result<(), Box<dyn std::error::Error>> {
    let key = key.to_string();
    log::info!(
        "will increment {:?}, from {:?}",
        key,
        potential_initial_value
    );
    // note: creating the data for the first time must be handled differently than mutating it once it exists
    let txn = match potential_initial_value {
        Some(iv) => {
            // note: compare against the count we observed earlier; if it has changed
            // since that read, the transaction fails and the caller moves on
            let observed_value_as_json = json!({"count": iv}).to_string();
            let incremented_value_as_json = json!({"count": iv + 1}).to_string();
            let txn = Txn::new();
            let cmp = Compare::value(key.clone(), CompareOp::Equal, observed_value_as_json);
            let succ = TxnOp::put(key.clone(), incremented_value_as_json, None);
            let fail = TxnOp::get(key.clone(), None);
            txn.when(vec![cmp]).and_then(vec![succ]).or_else(vec![fail])
        }
        None => {
            let initial_value_as_json = json!({"count": 1}).to_string();
            let txn = Txn::new();
            // note: version, not value here - we're checking that the key still does not exist
            let cmp = Compare::version(key.clone(), CompareOp::Equal, 0);
            let succ = TxnOp::put(key.clone(), initial_value_as_json, None);
            let fail = TxnOp::get(key.clone(), None);
            txn.when(vec![cmp]).and_then(vec![succ]).or_else(vec![fail])
        }
    };
    let txn_resp = client
        .txn(txn)
        .await
        .expect("got etcd response successfully");
    log::debug!("finished transaction: {:?}", txn_resp);
    if txn_resp.succeeded() {
        Ok(())
    } else {
        Err("transaction was not successful".into())
    }
}
async fn determine_shard_assignment(
    etcd: &AsyncMutex<Client>,
    total_shards: usize,
    replicas_per_shard: usize,
    port: u16,
) -> Result<usize, Box<dyn std::error::Error>> {
    log::info!("time to determine the shard");
    let mut client = etcd.lock().await;
    // note: essentially does greedy placement,
    // finding the first shard without a complete replica set
    // and adding this task id to the replica set
    for shard_id in 0..total_shards {
        let key = format!("shard_counts/shard_{}", shard_id);
        // note: between this read and the update, values can change
        // this is why we may need to move on to the next shard
        let response = client.get(key.clone(), None).await?;
        let maybe_count = response.kvs().first().map(|kv| {
            serde_json::from_slice::<ShardData>(kv.value())
                .unwrap()
                .count as usize
        });
        let count = maybe_count.unwrap_or(0);
        if count < replicas_per_shard {
            // note: has to take mut client here or it will deadlock with the lock already taken on etcd above
            match increment_shard_count(&mut client, &key, maybe_count).await {
                Ok(_) => {
                    log::info!("adding replica to count for {}", shard_id);
                    // note: store this so we can use it to look up and decrement later on shutdown
                    let value = json!({ "shard_id": shard_id }).to_string();
                    client.put(format!("task/{}", port), value, None).await?;
                    return Ok(shard_id);
                }
                Err(_) => {
                    log::info!(
                        "another task stole the assignment for {}, trying another shard",
                        shard_id
                    );
                    continue;
                }
            }
        }
    }
    Err("No available shards".into())
}
// note: not used locally, but kept around in case config should come from CodeCommit later
async fn fetch_config_from_codecommit(
    make_outbound_request: bool,
    default_number_shards: usize,
    default_replicas_per_shard: usize,
) -> Result<(usize, usize), Box<dyn std::error::Error>> {
    if make_outbound_request {
        let region_provider = RegionProviderChain::default_provider().or_else("us-west-2");
        let config = aws_config::from_env().region(region_provider).load().await;
        let client = codecommit::Client::new(&config);
        let content = client
            .get_file()
            .repository_name("my-config-repo")
            .file_path("config.json")
            .send()
            .await?
            .file_content()
            .as_ref()
            .to_vec();
        let config_data: serde_json::Value = serde_json::from_slice(&content)?;
        let total_shards = config_data["total_shards"]
            .as_u64()
            .expect("Expected total_shards") as usize;
        let replicas_per_shard = config_data["replicas_per_shard"]
            .as_u64()
            .expect("Expected replicas_per_shard") as usize;
        Ok((total_shards, replicas_per_shard))
    } else {
        let total_shards = default_number_shards;
        // these start at 1, that is, the primary is the first replica
        let replicas_per_shard = default_replicas_per_shard;
        Ok((total_shards, replicas_per_shard))
    }
}
struct AppState {
    etcd: Arc<AsyncMutex<Client>>,
    total_shards: usize,
    replicas_per_shard: usize,
    /// the fixed identity of the task
    port: u16,
    /// which type of node this is
    role: String,
    /// the dynamic and assigned identity of the shard
    shard_id: Mutex<Option<usize>>,
}

impl fmt::Debug for AppState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AppState")
            .field("etcd", &"Arc<AsyncMutex<Client>> { ... }") // placeholder display for the etcd client handle
            .field("total_shards", &self.total_shards)
            .field("replicas_per_shard", &self.replicas_per_shard)
            .field("port", &self.port)
            .finish()
    }
}
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    /// Role of the cete node
    #[arg(short, long)]
    role: String,
    /// total number of distinct shards
    #[arg(short, long, default_value_t = 1)]
    shards: usize,
    /// total number of replicas per shard
    #[arg(long, default_value_t = 1)]
    replicas: usize,
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
    env_logger::init();
    let args = Args::parse();
    let (total_shards, replicas_per_shard) =
        fetch_config_from_codecommit(false, args.shards, args.replicas)
            .await
            .expect("Failed to fetch configuration");
    let etcd_address = env::var("ETCD_ADDRESS").unwrap_or_else(|_| "http://localhost:2379".into());
    let client = Client::connect([etcd_address], None)
        .await
        .expect("connect to etcd");
    let (listener, port) = match args.role.as_str() {
        // note: just easier to test with when we have at least one node with a fixed address
        "front" => {
            let port = 3000;
            let listener = TcpListener::bind("0.0.0.0:3000").expect("Failed to bind to front port");
            (listener, port)
        }
        "inner" => {
            let listener =
                TcpListener::bind("0.0.0.0:0").expect("Failed to bind to inner, ephemeral port");
            let port = listener.local_addr().unwrap().port();
            (listener, port)
        }
        _ => {
            todo!("no other roles")
        }
    };
    let etcd = Arc::new(AsyncMutex::new(client));
    let binding = etcd.clone();
    let shard_id = determine_shard_assignment(&binding, total_shards, replicas_per_shard, port)
        .await
        .expect("assigned_shard");
    let data = web::Data::new(AppState {
        etcd: etcd.clone(),
        total_shards,
        replicas_per_shard,
        port,
        role: args.role,
        shard_id: Mutex::new(Some(shard_id)),
    });
    log::info!("AppState: {:?}", &data);
    // note: actix doesn't have its own shutdown hook, so we listen for signals and do cleanup manually
    tokio::spawn(async move {
        let mut terminate = signal(SignalKind::terminate()).unwrap();
        let mut interrupt = signal(SignalKind::interrupt()).unwrap();
        tokio::select! {
            _ = terminate.recv() => {
                println!("Received SIGTERM signal, shutting down...");
                cleanup_assignment(&etcd, port).await;
            }
            _ = interrupt.recv() => {
                println!("Received SIGINT signal, shutting down...");
                cleanup_assignment(&etcd, port).await;
            }
        }
    });
    HttpServer::new(move || {
        App::new()
            .wrap(middleware::Logger::default())
            .app_data(data.clone())
            .route("/me", web::get().to(show_assignment))
            .route("/", web::get().to(show_assignments))
            .route("/tasks", web::get().to(show_tasks))
            .route("/redo", web::get().to(recalculate_assignment))
    })
    .listen(listener)?
    .run()
    .await
}
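// Example local run (assumes an etcd reachable at ETCD_ADDRESS, defaulting to http://localhost:2379):
//   RUST_LOG=info cargo run -- --role front --shards 2 --replicas 2
//   RUST_LOG=info cargo run -- --role inner --shards 2 --replicas 2
//   curl localhost:3000/me      # this task's shard assignment
//   curl localhost:3000/        # replica counts per shard
//   curl localhost:3000/tasks   # task -> shard map plus liveness
//   curl localhost:3000/redo    # recompute this task's assignment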