diff --git a/tree/tests/chgrp/mod.rs b/tree/tests/chgrp/mod.rs index e1267168..67975f93 100644 --- a/tree/tests/chgrp/mod.rs +++ b/tree/tests/chgrp/mod.rs @@ -15,7 +15,7 @@ use std::{ self, fs::{MetadataExt, PermissionsExt}, }, - sync::Once, + sync::{Once, RwLock}, thread, time::Duration, }; @@ -33,13 +33,21 @@ fn chgrp_test(args: &[&str], expected_output: &str, expected_error: &str, expect }); } -static INIT_GROUPS: Once = Once::new(); -static mut PRIMARY_GROUP: String = String::new(); -static mut SECONDARY_GROUP: String = String::new(); +#[derive(Clone)] +struct Groups { + primary_group: String, + secondary_group: String, + gid1: u32, + gid2: u32, +} -static mut GID1: u32 = 0; -static mut GID2: u32 = 0; -static INIT_GID: Once = Once::new(); +static INIT_GROUPS: Once = Once::new(); +static GROUPS: RwLock = RwLock::new(Groups { + primary_group: String::new(), + secondary_group: String::new(), + gid1: 0, + gid2: 0, +}); fn get_group_id(name: &CStr) -> u32 { unsafe { @@ -54,12 +62,22 @@ fn get_group_id(name: &CStr) -> u32 { // Return two groups that the current user belongs to. fn get_groups() -> ((String, u32), (String, u32)) { - // Linux - (primary group of current user, a group in the supplemental group list that - // is not the primary group) - // macOS - ("staff", "admin") - let (g1, g2) = if cfg!(target_os = "linux") { - unsafe { - INIT_GROUPS.call_once(|| { + // Guard the writes to the GROUPS with a `Once` + INIT_GROUPS.call_once(|| { + let mut groups = GROUPS.write().unwrap(); + let Groups { + primary_group, + secondary_group, + gid1, + gid2, + } = &mut *groups; + + // Initialize group strings + // Linux - (primary group of current user, a group in the supplemental group list that + // is not the primary group) + // macOS - ("staff", "admin") + if cfg!(target_os = "linux") { + unsafe { let uid = libc::getuid(); let pw = libc::getpwuid(uid); if pw.is_null() { @@ -73,7 +91,7 @@ fn get_groups() -> ((String, u32), (String, u32)) { } let gr_name = CStr::from_ptr((&*gr).gr_name).to_owned(); - PRIMARY_GROUP = gr_name.to_str().unwrap().to_owned(); + *primary_group = gr_name.to_str().unwrap().to_owned(); let mut count = libc::getgroups(0, std::ptr::null_mut()); if count < 0 { @@ -83,18 +101,9 @@ fn get_groups() -> ((String, u32), (String, u32)) { ); } - let mut groups_ptr: *mut libc::gid_t = - libc::malloc(std::mem::size_of::() * count as usize) - as *mut libc::gid_t; - - if groups_ptr.is_null() { - panic!( - "unable to allocate memory for groups list: {}", - io::Error::last_os_error() - ); - } + let mut groups_list: Vec = vec![0; count as usize]; - count = libc::getgroups(count, groups_ptr); + count = libc::getgroups(count, groups_list.as_mut_ptr()); match count { _ if count < 2 => panic!("user must be a member of at least two groups"), -1 => panic!( @@ -104,58 +113,63 @@ fn get_groups() -> ((String, u32), (String, u32)) { _ => {} } - for _ in 0..count { - if groups_ptr.is_null() { - panic!("unable to get second group: reached end of group list"); - } - + for second_gid in groups_list { // Skip over the primary_gid - if *groups_ptr == primary_gid { - groups_ptr = groups_ptr.offset(1); + if second_gid == primary_gid { continue; } else { - let second_gid = *groups_ptr; let sec_grent = libc::getgrgid(second_gid); if sec_grent.is_null() { panic!("Unable to get group entry for secondary group id {second_gid}"); } let sec_gr_name = CStr::from_ptr((&*sec_grent).gr_name).to_owned(); - SECONDARY_GROUP = sec_gr_name.to_str().unwrap().to_owned(); + *secondary_group = sec_gr_name.to_str().unwrap().to_owned(); break; } } - if SECONDARY_GROUP == "" { + + if secondary_group.is_empty() { panic!("unable to find suitable secondary group"); } - }); - - (PRIMARY_GROUP.clone(), SECONDARY_GROUP.clone()) + } + } else if cfg!(target_os = "macos") { + *primary_group = String::from("staff"); + *secondary_group = String::from("admin"); + } else { + panic!("Unsupported OS"); } - } else if cfg!(target_os = "macos") { - ("staff".to_owned(), "admin".to_owned()) - } else { - panic!("Unsupported OS") - }; - - unsafe { - INIT_GID.call_once(|| { - let g1_cstr = CString::new(g1.as_str()).unwrap(); - let g2_cstr = CString::new(g2.as_str()).unwrap(); - - GID1 = get_group_id(&g1_cstr); - GID2 = get_group_id(&g2_cstr); - }); - // Must be initialized - assert_ne!(GID1, 0); - assert_ne!(GID2, 0); + // Initialize the group IDs corresponding to the group strings + { + let g1_cstr = CString::new(primary_group.as_str()).unwrap(); + let g2_cstr = CString::new(secondary_group.as_str()).unwrap(); - // Must be different groups - assert_ne!(GID1, GID2); + *gid1 = get_group_id(&g1_cstr); + *gid2 = get_group_id(&g2_cstr); + } + }); - ((g1, GID1), (g2, GID2)) - } + // The reads to GROUPS should not have conflicts with the writes because: + // 1) `Once` will block until the initialization is finished. + // 2) The writes are only done inside `Once::call_once`. + let groups = GROUPS.read().unwrap(); + let groups_ref = &*groups; + let Groups { + primary_group, + secondary_group, + gid1, + gid2, + } = groups_ref.clone(); + + // Must be initialized + assert_ne!(gid1, 0); + assert_ne!(gid2, 0); + + // Must be different groups + assert_ne!(gid1, gid2); + + ((primary_group, gid1), (secondary_group, gid2)) } fn file_gid(path: &str) -> io::Result {