diff --git a/Cargo.lock b/Cargo.lock index 226dbf4..33e8907 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -122,6 +122,7 @@ dependencies = [ "csv", "env_logger", "indexmap", + "indicatif", "log", "regex", "reqwest", @@ -184,6 +185,19 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "csv" version = "1.4.0" @@ -216,6 +230,12 @@ dependencies = [ "syn", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "env_filter" version = "0.1.4" @@ -562,6 +582,19 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -665,6 +698,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "once_cell" version = "1.21.3" @@ -1250,6 +1289,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" @@ -1406,6 +1451,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" diff --git a/Cargo.toml b/Cargo.toml index 23bfef9..6b57c03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,4 @@ csv = "1.3.0" serde_json = { version = "1.0", features = ["preserve_order"] } indexmap = "2.12" clap_complete = "4.5" +indicatif = "0.17" diff --git a/examples/download_test.rs b/examples/download_test.rs index 66fd2b2..63baf80 100644 --- a/examples/download_test.rs +++ b/examples/download_test.rs @@ -4,7 +4,7 @@ use cfspeedtest::OutputFormat; fn main() { println!("Testing download speed with 10MB of payload"); - let download_speed = test_download( + let (download_speed, _) = test_download( &reqwest::blocking::Client::new(), 10_000_000, OutputFormat::None, // don't write to stdout while running the test diff --git a/src/progress.rs b/src/progress.rs index 9c6d150..7889857 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -1,15 +1,31 @@ -use std::io::stdout; -use std::io::Write; - -pub fn print_progress(name: &str, curr: u32, max: u32) { - const BAR_LEN: u32 = 30; - let progress_line = ((curr as f32 / max as f32) * BAR_LEN as f32) as u32; - let remaining_line = BAR_LEN - progress_line; - print!( - "\r{:<15} [{}{}]", - name, - (0..progress_line).map(|_| "=").collect::(), - (0..remaining_line).map(|_| "-").collect::(), - ); - stdout().flush().expect("error printing progress bar"); +use indicatif::{ProgressBar, ProgressStyle}; + +pub struct Progress { + bar: ProgressBar, +} + +impl Progress { + pub fn new(name: &str, max: u32) -> Self { + let bar = ProgressBar::new(max as u64); + bar.set_style( + ProgressStyle::default_bar() + .template("{prefix:<15} [{bar:30}] {msg}") + .unwrap() + .progress_chars("=-"), + ); + bar.set_prefix(name.to_string()); + Progress { bar } + } + + pub fn set_position(&self, curr: u32) { + self.bar.set_position(curr as u64); + } + + pub fn finish(&self) { + self.bar.finish(); + } + + pub fn set_message(&self, msg: impl Into>) { + self.bar.set_message(msg); + } } diff --git a/src/speedtest.rs b/src/speedtest.rs index 601e3bb..6505c51 100644 --- a/src/speedtest.rs +++ b/src/speedtest.rs @@ -2,12 +2,12 @@ use crate::measurements::format_bytes; use crate::measurements::log_measurements; use crate::measurements::LatencyMeasurement; use crate::measurements::Measurement; -use crate::progress::print_progress; +use crate::progress::Progress; use crate::OutputFormat; use crate::SpeedTestCLIOptions; use log; use regex::Regex; -use reqwest::{blocking::Client, StatusCode}; +use reqwest::blocking::Client; use serde::Serialize; use std::{ fmt::Display, @@ -160,18 +160,29 @@ pub fn run_latency_test( output_format: OutputFormat, ) -> (Vec, f64) { let mut measurements: Vec = Vec::new(); + let progress = if output_format == OutputFormat::StdOut { + Some(Progress::new("latency test", nr_latency_tests)) + } else { + None + }; + for i in 0..nr_latency_tests { - if output_format == OutputFormat::StdOut { - print_progress("latency test", i + 1, nr_latency_tests); + if let Some(ref pb) = progress { + pb.set_position(i + 1); } let latency = test_latency(client); measurements.push(latency); } + + if let Some(pb) = progress { + pb.finish(); + } + let avg_latency = measurements.iter().sum::() / measurements.len() as f64; if output_format == OutputFormat::StdOut { println!( - "\nAvg GET request latency {avg_latency:.2} ms (RTT excluding server processing time)\n" + "Avg GET request latency {avg_latency:.2} ms (RTT excluding server processing time)\n" ); } (measurements, avg_latency) @@ -222,7 +233,7 @@ const TIME_THRESHOLD: Duration = Duration::from_secs(5); pub fn run_tests( client: &Client, - test_fn: fn(&Client, usize, OutputFormat) -> f64, + test_fn: fn(&Client, usize, OutputFormat) -> (f64, Duration), test_type: TestType, payload_sizes: Vec, nr_tests: u32, @@ -233,29 +244,41 @@ pub fn run_tests( for payload_size in payload_sizes { log::debug!("running tests for payload_size {payload_size}"); let start = Instant::now(); + + let progress = if output_format == OutputFormat::StdOut { + Some(Progress::new( + &format!("{:?} {:<5}", test_type, format_bytes(payload_size)), + nr_tests, + )) + } else { + None + }; + for i in 0..nr_tests { - if output_format == OutputFormat::StdOut { - print_progress( - &format!("{:?} {:<5}", test_type, format_bytes(payload_size)), - i, - nr_tests, + let (mbit, duration) = test_fn(client, payload_size, output_format); + + if let Some(ref pb) = progress { + pb.set_position(i + 1); + let message = format!( + " {:>6.2} mbit/s | {:>5} in {:>4}ms", + mbit, + format_bytes(payload_size), + duration.as_millis() ); + pb.set_message(message); } - let mbit = test_fn(client, payload_size, output_format); measurements.push(Measurement { test_type, payload_size, mbit, }); } - if output_format == OutputFormat::StdOut { - print_progress( - &format!("{:?} {:<5}", test_type, format_bytes(payload_size)), - nr_tests, - nr_tests, - ); + + if let Some(pb) = progress { + pb.finish(); println!() } + let duration = start.elapsed(); // only check TIME_THRESHOLD if dynamic max payload sizing is not disabled @@ -267,62 +290,43 @@ pub fn run_tests( measurements } -pub fn test_upload(client: &Client, payload_size_bytes: usize, output_format: OutputFormat) -> f64 { +pub fn test_upload( + client: &Client, + payload_size_bytes: usize, + _output_format: OutputFormat, +) -> (f64, Duration) { let url = &format!("{BASE_URL}/{UPLOAD_URL}"); let payload: Vec = vec![1; payload_size_bytes]; let req_builder = client.post(url).body(payload); - let (mut response, status_code, mbits, duration) = { + let (mut response, mbits, duration) = { let start = Instant::now(); let response = req_builder.send().expect("failed to get response"); - let status_code = response.status(); let duration = start.elapsed(); let mbits = (payload_size_bytes as f64 * 8.0 / 1_000_000.0) / duration.as_secs_f64(); - (response, status_code, mbits, duration) + (response, mbits, duration) }; // Drain response after timing so we don't skew upload measurement. let _ = std::io::copy(&mut response, &mut std::io::sink()); - if output_format == OutputFormat::StdOut { - print_current_speed(mbits, duration, status_code, payload_size_bytes); - } - mbits + (mbits, duration) } pub fn test_download( client: &Client, payload_size_bytes: usize, - output_format: OutputFormat, -) -> f64 { + _output_format: OutputFormat, +) -> (f64, Duration) { let url = &format!("{BASE_URL}/{DOWNLOAD_URL}{payload_size_bytes}"); let req_builder = client.get(url); - let (status_code, mbits, duration) = { + let (mbits, duration) = { let start = Instant::now(); let mut response = req_builder.send().expect("failed to get response"); - let status_code = response.status(); // Stream the body to avoid buffering the full payload in memory. let _ = std::io::copy(&mut response, &mut std::io::sink()); let duration = start.elapsed(); let mbits = (payload_size_bytes as f64 * 8.0 / 1_000_000.0) / duration.as_secs_f64(); - (status_code, mbits, duration) + (mbits, duration) }; - if output_format == OutputFormat::StdOut { - print_current_speed(mbits, duration, status_code, payload_size_bytes); - } - mbits -} - -fn print_current_speed( - mbits: f64, - duration: Duration, - status_code: StatusCode, - payload_size_bytes: usize, -) { - print!( - " {:>6.2} mbit/s | {:>5} in {:>4}ms -> status: {} ", - mbits, - format_bytes(payload_size_bytes), - duration.as_millis(), - status_code - ); + (mbits, duration) } pub fn fetch_metadata(client: &Client) -> Result {