diff --git a/crates/proc_macro_srv/src/lib.rs b/crates/proc_macro_srv/src/lib.rs index 1698bc967d..52693547e5 100644 --- a/crates/proc_macro_srv/src/lib.rs +++ b/crates/proc_macro_srv/src/lib.rs @@ -16,7 +16,9 @@ mod abis; use std::{ collections::{hash_map::Entry, HashMap}, - env, fs, + env, + ffi::OsString, + fs, path::{Path, PathBuf}, time::SystemTime, }; @@ -38,9 +40,8 @@ impl ProcMacroSrv { PanicMessage(format!("failed to load macro: {}", err)) })?; - let mut prev_env = HashMap::new(); + let prev_env = EnvSnapshot::new(); for (k, v) in &task.env { - prev_env.insert(k.as_str(), env::var_os(k)); env::set_var(k, v); } let prev_working_dir = match task.current_dir { @@ -60,12 +61,8 @@ impl ProcMacroSrv { .expand(&task.macro_name, ¯o_body, attributes.as_ref()) .map(|it| FlatTree::new(&it)); - for (k, _) in &task.env { - match &prev_env[k.as_str()] { - Some(v) => env::set_var(k, v), - None => env::remove_var(k), - } - } + prev_env.rollback(); + if let Some(dir) = prev_working_dir { if let Err(err) = std::env::set_current_dir(&dir) { eprintln!( @@ -101,6 +98,32 @@ impl ProcMacroSrv { } } +struct EnvSnapshot { + vars: HashMap, +} + +impl EnvSnapshot { + fn new() -> EnvSnapshot { + EnvSnapshot { vars: env::vars_os().collect() } + } + + fn rollback(self) { + let mut old_vars = self.vars; + for (name, value) in env::vars_os() { + let old_value = old_vars.remove(&name); + if old_value != Some(value) { + match old_value { + None => env::remove_var(name), + Some(old_value) => env::set_var(name, old_value), + } + } + } + for (name, old_value) in old_vars { + env::set_var(name, old_value) + } + } +} + pub mod cli; #[cfg(test)]