|
| 1 | +use std::{collections::HashMap, str::FromStr}; |
| 2 | +use syntect::easy::HighlightLines; |
| 3 | +use syntect::parsing::SyntaxSet; |
| 4 | +use syntect::highlighting::{ThemeSet, Style}; |
| 5 | +use syntect::util::{as_24_bit_terminal_escaped, LinesWithEndings}; |
| 6 | +use anyhow::{anyhow, Result}; |
| 7 | +use clap::{AppSettings, Clap}; |
| 8 | +use colored::*; |
| 9 | +use mime::Mime; |
| 10 | +use reqwest::{Client, header, Response, Url}; |
| 11 | + |
| 12 | +// 以下部分用于处理 CLI |
| 13 | + |
| 14 | +// 定义 HTTPie 的 CLI 的主入口,它包含若干个子命令 |
| 15 | +// 下面 /// 的注释是文档,clap 会将其作为 CLI 的帮助 |
| 16 | + |
| 17 | +/// A naive httpie implementation with Rust, can you imagine how easy it is? |
| 18 | +#[derive(Clap, Debug)] |
| 19 | +#[clap(version = "1.0", author = "Custer<[email protected]")] |
| 20 | +#[clap(setting = AppSettings::ColoredHelp)] |
| 21 | +struct Opts { |
| 22 | + #[clap(subcommand)] |
| 23 | + subcmd: SubCommand, |
| 24 | +} |
| 25 | + |
| 26 | +// 子命令分别对应不同的 HTTP 方法,目前只支持 get/post |
| 27 | +#[derive(Clap, Debug)] |
| 28 | +enum SubCommand { |
| 29 | + Get(Get), |
| 30 | + Post(Post), |
| 31 | + // 暂且不支持其他 HTTP 方法 |
| 32 | +} |
| 33 | + |
| 34 | +// get 子命令 |
| 35 | + |
| 36 | +/// feed get with an url and we will retrieve the response for you |
| 37 | +#[derive(Clap, Debug)] |
| 38 | +struct Get { |
| 39 | + /// HTTP 请求的 URL |
| 40 | + #[clap(parse(try_from_str = parse_url))] |
| 41 | + url: String, |
| 42 | +} |
| 43 | + |
| 44 | +fn parse_url(s: &str) -> Result<String> { |
| 45 | + // 这里仅仅检查一下 URL 是否合法 |
| 46 | + let _url: Url = s.parse()?; |
| 47 | + Ok(s.into()) |
| 48 | +} |
| 49 | + |
| 50 | +// post 子命令,需要输入一个 URL,和若干个可选的 key=value,用于提供 json body |
| 51 | + |
| 52 | +/// feed post with an url and optional key=value pairs. We will post the data |
| 53 | +/// as JSON, and retrieve the response for you |
| 54 | +#[derive(Clap, Debug)] |
| 55 | +struct Post { |
| 56 | + /// HTTP 请求的 URL |
| 57 | + #[clap(parse(try_from_str = parse_url))] |
| 58 | + url: String, |
| 59 | + /// HTTP 请求的 body |
| 60 | + #[clap(parse(try_from_str = parse_kv_pair))] |
| 61 | + body: Vec<KvPair>, |
| 62 | +} |
| 63 | + |
| 64 | +/// 命令行中的 key=value 可以通过 parse_kv_pair 解析成 KvPair 结构 |
| 65 | +#[derive(Debug, PartialEq)] |
| 66 | +struct KvPair { |
| 67 | + k: String, |
| 68 | + v: String, |
| 69 | +} |
| 70 | + |
| 71 | +/// 当我们实现 FromStr trait 后,可以用 str.parse() 方法将字符串解析成 KvPair |
| 72 | +impl FromStr for KvPair { |
| 73 | + type Err = anyhow::Error; |
| 74 | + |
| 75 | + fn from_str(s: &str) -> Result<Self, Self::Err> { |
| 76 | + // 使用 = 进行 split,这会得到一个迭代器 |
| 77 | + let mut split = s.split("="); |
| 78 | + let err = || anyhow!(format!("Failed to parse {}", s)); |
| 79 | + Ok(Self { |
| 80 | + // 从迭代器中取第一个结果作为 key,迭代器返回 Some(T)/None |
| 81 | + // 我们将其转换成 Ok(T)/Err(E),然后用 ? 处理错误 |
| 82 | + k: (split.next().ok_or_else(err)?).to_string(), |
| 83 | + // 从迭代器中取第二个结果作为 value |
| 84 | + v: (split.next().ok_or_else(err)?).to_string(), |
| 85 | + }) |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +/// 因为我们为 KvPair 实现了 FromStr,这里可以直接 s.parse() 得到 KvPair |
| 90 | +fn parse_kv_pair(s: &str) -> Result<KvPair> { |
| 91 | + Ok(s.parse()?) |
| 92 | +} |
| 93 | + |
| 94 | +/// 处理 get 子命令 |
| 95 | +async fn get(client: Client, args: &Get) -> Result<()> { |
| 96 | + let resp = client.get(&args.url).send().await?; |
| 97 | + Ok(print_resp(resp).await?) |
| 98 | +} |
| 99 | + |
| 100 | +/// 处理 post 子命令 |
| 101 | +async fn post(client: Client, args: &Post) -> Result<()> { |
| 102 | + let mut body = HashMap::new(); |
| 103 | + for pair in args.body.iter() { |
| 104 | + body.insert(&pair.k, &pair.v); |
| 105 | + } |
| 106 | + let resp = client.post(&args.url).json(&body).send().await?; |
| 107 | + Ok(print_resp(resp).await?) |
| 108 | +} |
| 109 | + |
| 110 | +// 打印服务器版本号 + 状态码 |
| 111 | +fn print_status(resp: &Response) { |
| 112 | + let status = format!("{:?} {}", resp.version(), resp.status()).blue(); |
| 113 | + println!("{}\n", status); |
| 114 | +} |
| 115 | + |
| 116 | +// 打印服务端返回的 HTTP header |
| 117 | +fn print_headers(resp: &Response) { |
| 118 | + for (name, value) in resp.headers() { |
| 119 | + println!("{}: {:?}", name.to_string().green(), value); |
| 120 | + } |
| 121 | + print!("\n"); |
| 122 | +} |
| 123 | + |
| 124 | +/// 打印服务器返回的 HTTP body |
| 125 | +fn print_body(m: Option<Mime>, body: &String) { |
| 126 | + match m { |
| 127 | + // 对于 "application/json" 我们 pretty print |
| 128 | + Some(v) if v == mime::APPLICATION_JSON => { |
| 129 | + // println!("{}", jsonxf::pretty_print(body).unwrap().cyan()) |
| 130 | + print_syntect(body); |
| 131 | + } |
| 132 | + // 其他 mime type,我们就直接输出 |
| 133 | + _ => println!("{}", body), |
| 134 | + } |
| 135 | +} |
| 136 | + |
| 137 | +fn print_syntect(s: &str) { |
| 138 | + // Load these once at the start of your program |
| 139 | + let ps = SyntaxSet::load_defaults_newlines(); |
| 140 | + let ts = ThemeSet::load_defaults(); |
| 141 | + let syntax = ps.find_syntax_by_extension("json").unwrap(); |
| 142 | + let mut h = HighlightLines::new(syntax, &ts.themes["base16-ocean.dark"]); |
| 143 | + for line in LinesWithEndings::from(s) { |
| 144 | + let ranges: Vec<(Style, &str)> = h.highlight(line, &ps); |
| 145 | + let escaped = as_24_bit_terminal_escaped(&ranges[..], true); |
| 146 | + println!("{}", escaped); |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +/// 打印整个响应 |
| 151 | +async fn print_resp(resp: Response) -> Result<()> { |
| 152 | + print_status(&resp); |
| 153 | + print_headers(&resp); |
| 154 | + let mime = get_context_type(&resp); |
| 155 | + let body = resp.text().await?; |
| 156 | + print_body(mime, &body); |
| 157 | + Ok(()) |
| 158 | +} |
| 159 | + |
| 160 | +/// 将服务器返回的 content-type 解析成 Mime 类型 |
| 161 | +fn get_context_type(resp: &Response) -> Option<Mime> { |
| 162 | + resp.headers() |
| 163 | + .get(header::CONTENT_TYPE) |
| 164 | + .map(|v| v.to_str().unwrap().parse().unwrap()) |
| 165 | +} |
| 166 | + |
| 167 | +/// 程序的入口函数,因为在 HTTP 请求时我们使用了异步处理,所以这里引入 tokio |
| 168 | +#[tokio::main] |
| 169 | +async fn main() -> Result<()> { |
| 170 | + let opts: Opts = Opts::parse(); |
| 171 | + let mut headers = header::HeaderMap::new(); |
| 172 | + // 为我们的 HTTP 客户端添加一些缺省的 HTTP 头 |
| 173 | + headers.insert("X-POWERED-BY", "Rust".parse()?); |
| 174 | + headers.insert(header::USER_AGENT, "Rust Httpie".parse()?); |
| 175 | + // 生成一个 HTTP 客户端 |
| 176 | + let client = reqwest::Client::builder() |
| 177 | + .default_headers(headers) |
| 178 | + .build()?; |
| 179 | + let result = match opts.subcmd { |
| 180 | + SubCommand::Get(ref args) => get(client, args).await?, |
| 181 | + SubCommand::Post(ref args) => post(client, args).await?, |
| 182 | + }; |
| 183 | + Ok(result) |
| 184 | +} |
| 185 | + |
| 186 | +// 仅在 cargo test 时才编译 |
| 187 | +#[cfg(test)] |
| 188 | +mod tests { |
| 189 | + use super::*; |
| 190 | + |
| 191 | + #[test] |
| 192 | + fn parse_url_works() { |
| 193 | + assert!(parse_url("abc").is_err()); |
| 194 | + assert!(parse_url("http://abc.xyz").is_ok()); |
| 195 | + assert!(parse_url("https://httpbin.org/post").is_ok()); |
| 196 | + } |
| 197 | + |
| 198 | + #[test] |
| 199 | + fn parse_kv_pair_works() { |
| 200 | + assert!(parse_kv_pair("a").is_err()); |
| 201 | + assert_eq!( |
| 202 | + parse_kv_pair("a=1").unwrap(), |
| 203 | + KvPair { |
| 204 | + k: "a".into(), |
| 205 | + v: "1".into() |
| 206 | + } |
| 207 | + ); |
| 208 | + assert_eq!( |
| 209 | + parse_kv_pair("b=").unwrap(), |
| 210 | + KvPair { |
| 211 | + k: "b".into(), |
| 212 | + v: "".into() |
| 213 | + } |
| 214 | + ); |
| 215 | + } |
| 216 | +} |
0 commit comments