Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update last_started when fetching tunnel #165

Merged
35 changes: 9 additions & 26 deletions worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,43 +116,26 @@ async fn get_tunnel_handler(
let tunnel_data: Option<TunnelData> = kv.get(&tunnel_name).json().await.unwrap();

match tunnel_data {
Some(tunnel_data) => {
// TODO: Update the last_started field to `now`.
Some(mut tunnel_data) => {
tunnel_data.last_started = worker::Date::now().as_millis();
kv.put(&tunnel_name, &tunnel_data)
.unwrap()
.execute()
.await
.unwrap();

return Json(tunnel_data);
}
None => {
let (tunnel_id, tunnel_secret) = tunnel::create_tunnel(
let tunnel_data = tunnel::create_tunnel(
&state.cloudflare.api_token,
&state.cloudflare.account_id,
&tunnel_name,
)
.await
.unwrap();

tunnel::create_dns_record(
&state.cloudflare.api_token,
&state.cloudflare.tunnel_zone_id,
&tunnel_id,
&tunnel_name,
)
.await
.unwrap();

let zone_domain = tunnel::get_zone_domain(
&state.cloudflare.api_token,
&state.cloudflare.tunnel_zone_id,
)
.await;

let tunnel_data = TunnelData {
account_id: state.cloudflare.account_id,
name: tunnel_name.clone(),
url: format!("https://{}.{}", &tunnel_name, &zone_domain),
id: tunnel_id,
secret: tunnel_secret,
last_started: worker::Date::now().as_millis(),
};

kv.put(&tunnel_name, &tunnel_data)
.unwrap()
.execute()
Expand Down
270 changes: 64 additions & 206 deletions worker/src/tunnel.rs
Original file line number Diff line number Diff line change
@@ -1,229 +1,87 @@
// TODO: Replace String errors for proper error Enum
use std::fmt::Display;

use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use crate::TunnelData;

use serde::{Deserialize, Serialize};
use worker::{console_error, console_log};

#[derive(Serialize, Deserialize, Debug)]
struct GetTunnelApiResponse {
result: Vec<TunnelResultItem>,
}

#[derive(Serialize, Deserialize, Debug)]
struct TunnelResultItem {
id: String,
name: String,
deleted_at: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
struct TokenApiResponse {
result: String,
}

#[derive(Serialize, Deserialize, Debug)]
struct CreateTunnelRequest {
name: String,
tunnel_secret: String,
}

#[derive(Serialize, Deserialize, Debug)]
struct CreateDNSRecordResponse {
result: DNSRecord,
}

#[derive(Serialize, Deserialize, Debug)]
struct DNSRecord {
content: String,
name: String,
r#type: String,
proxied: bool,
}
#[derive(Serialize, Deserialize, Debug)]
struct CreateTunnelResponse {
result: TunnelResultItem,
}

#[derive(Serialize, Deserialize)]
struct Config {
url: String,
tunnel: String,
#[serde(rename = "credentials-file")]
credentials_file: String,
#[derive(Debug)]
pub enum Error {
CreateCloudflareTunnel(String),
CreateDNS(String),
FetchZone(String),
}

pub async fn _get_tunnel_id(
api_token: &str,
account_id: &str,
tunnel_name: &str,
) -> Result<Option<String>, String> {
let url = format!(
"https://api.cloudflare.com/client/v4/accounts/{}/cfd_tunnel",
account_id
);
let (client, headers) = prepare_client_and_headers(api_token)?;
let query_url = format!("{}?name={}", url, tunnel_name);

let parsed: GetTunnelApiResponse =
send_request(&client, &query_url, headers, None, "GET").await?;
if parsed.result.is_empty() {
Ok(None)
} else {
// Check if there exists a tunnel with this name that hasn't been deleted
match parsed
.result
.iter()
.find(|tunnel| tunnel.deleted_at.is_none())
{
Some(tunnel) => Ok(Some(tunnel.id.clone())),
None => Ok(None),
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::CreateCloudflareTunnel(text) => write!(f, "Failed to crate tunnel: {}", text),
Error::CreateDNS(text) => write!(f, "Failed to crate DNS record: {}", text),
Error::FetchZone(text) => write!(f, "Failed to fetch Zone details: {}", text),
}
}
}

impl std::error::Error for Error {}

pub async fn create_tunnel(
api_token: &str,
account_id: &str,
tunnel_name: &str,
// TODO: Make this tuple into a proper type
) -> Result<(String, String), String> {
let tunnel_secret = crate::generate_secret();
let url = format!(
"https://api.cloudflare.com/client/v4/accounts/{}/cfd_tunnel",
account_id,
);
let (client, headers) = prepare_client_and_headers(api_token)?;
let body = serde_json::to_string(&CreateTunnelRequest {
name: tunnel_name.to_string(),
tunnel_secret: tunnel_secret.clone(),
})
.expect("body to be valid");

let parsed: CreateTunnelResponse =
send_request(&client, &url, headers, Some(body), "POST").await?;

Ok((parsed.result.id, tunnel_secret))
}

pub async fn create_dns_record(
api_token: &str,
zone_id: &str,
tunnel_id: &str,
tunnel_name: &str,
) -> Result<(), String> {
let url = format!(
"https://api.cloudflare.com/client/v4/zones/{}/dns_records",
zone_id
);
let (client, headers) = prepare_client_and_headers(api_token)?;
let body = serde_json::to_string(&DNSRecord {
name: tunnel_name.to_string(),
content: format!("{}.cfargotunnel.com", tunnel_id),
r#type: "CNAME".to_string(),
proxied: true,
})
.expect("body to be valid");

let _parsed: CreateDNSRecordResponse =
send_request(&client, &url, headers, Some(body), "POST").await?;
Ok(())
}

pub async fn get_zone_domain(api_token: &str, zone_id: &str) -> String {
#[derive(Deserialize)]
struct ZoneResponseResult {
name: String,
}

#[derive(Deserialize)]
struct ZoneResponse {
result: ZoneResponseResult,
}
) -> Result<TunnelData, Error> {
let client = crate::cloudflare_client(api_token);
let tunnel_secret = crate::generate_secret();

let url = format!("https://api.cloudflare.com/client/v4/zones/{}", &zone_id);
let (client, headers) =
prepare_client_and_headers(api_token).expect("client to be proper built");
let create_tunnel_req = cloudflare::endpoints::cfd_tunnel::create_tunnel::CreateTunnel {
account_identifier: account_id,
params: cloudflare::endpoints::cfd_tunnel::create_tunnel::Params {
name: tunnel_name,
tunnel_secret: &tunnel_secret.as_bytes().to_vec(),
config_src: &cloudflare::endpoints::cfd_tunnel::ConfigurationSrc::Local,
metadata: None,
},
};

let zone_response: ZoneResponse = send_request(&client, &url, headers, None, "GET")
let tunnel = client
.request(&create_tunnel_req)
.await
.unwrap();

zone_response.result.name
}

// Helper to create an HTTP client and prepare headers
fn prepare_client_and_headers(api_token: &str) -> Result<(reqwest::Client, HeaderMap), String> {
// this should be a string, not a result
// let bearer_token = sys.get_env("LINKUP_CF_API_TOKEN")?;
let client = reqwest::Client::new();
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", api_token)).expect("api_token should be valid"),
);

Ok((client, headers))
}

// Helper for sending requests and handling responses
async fn send_request<T: for<'de> serde::Deserialize<'de>>(
client: &reqwest::Client,
url: &str,
headers: HeaderMap,
body: Option<String>,
method: &str,
) -> Result<T, String> {
console_log!(
"Sending request {} '{}' with body '{:?}'",
method,
url,
&body
);

let builder = match method {
"GET" => client.get(url),
"POST" => client.post(url),
_ => return Err("Invalid HTTP method".into()),
.map_err(|err| Error::CreateCloudflareTunnel(err.to_string()))?
.result;

let create_dns_req = cloudflare::endpoints::dns::CreateDnsRecord {
zone_identifier: zone_id,
params: cloudflare::endpoints::dns::CreateDnsRecordParams {
proxied: Some(true),
name: tunnel_name,
content: cloudflare::endpoints::dns::DnsContent::CNAME {
content: format!("{}.cfargotunnel.com", tunnel.id),
},
ttl: None,
priority: None,
},
};

let builder = builder.headers(headers);
let builder = if let Some(body) = body {
builder.body(body)
} else {
builder
};
client
.request(&create_dns_req)
.await
.map_err(|err| Error::CreateDNS(err.to_string()))?;

let response = builder.send().await.unwrap();
let status = response.status();
let get_zone_req = cloudflare::endpoints::zone::ZoneDetails {
identifier: zone_id,
};

if status.is_success() {
let response_body = response.text().await.unwrap();
console_log!(
"Response: status: {}; content: '{}'",
&status,
&response_body
);
let zone = client
.request(&get_zone_req)
.await
.map_err(|err| Error::FetchZone(err.to_string()))?
.result;

match serde_json::from_str(&response_body) {
Ok(val) => {
// console_log!("{}", val.clone());
Ok(val)
}
Err(e) => {
console_error!("{:?}", e);
Err("Wot 2".to_string())
}
}
} else {
let response_body = response.text().await.unwrap();
console_log!(
"Response: status: {}; content: '{}'",
&status,
&response_body
);
let tunnel_data = TunnelData {
account_id: account_id.to_string(),
name: tunnel_name.to_string(),
url: format!("https://{}.{}", &tunnel_name, &zone.name),
id: tunnel.id.to_string(),
secret: tunnel_secret,
last_started: worker::Date::now().as_millis(),
};

Err("Wot".into())
}
Ok(tunnel_data)
}