Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove static mut in chgrp tests and fix memleak inside get_groups #386

Merged
merged 1 commit into from
Feb 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 73 additions & 59 deletions tree/tests/chgrp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::{
self,
fs::{MetadataExt, PermissionsExt},
},
sync::Once,
sync::{Once, RwLock},
thread,
time::Duration,
};
Expand All @@ -33,13 +33,21 @@ fn chgrp_test(args: &[&str], expected_output: &str, expected_error: &str, expect
});
}

static INIT_GROUPS: Once = Once::new();
static mut PRIMARY_GROUP: String = String::new();
static mut SECONDARY_GROUP: String = String::new();
#[derive(Clone)]
struct Groups {
primary_group: String,
secondary_group: String,
gid1: u32,
gid2: u32,
}

static mut GID1: u32 = 0;
static mut GID2: u32 = 0;
static INIT_GID: Once = Once::new();
static INIT_GROUPS: Once = Once::new();
static GROUPS: RwLock<Groups> = RwLock::new(Groups {
primary_group: String::new(),
secondary_group: String::new(),
gid1: 0,
gid2: 0,
});

fn get_group_id(name: &CStr) -> u32 {
unsafe {
Expand All @@ -54,12 +62,22 @@ fn get_group_id(name: &CStr) -> u32 {

// Return two groups that the current user belongs to.
fn get_groups() -> ((String, u32), (String, u32)) {
// Linux - (primary group of current user, a group in the supplemental group list that
// is not the primary group)
// macOS - ("staff", "admin")
let (g1, g2) = if cfg!(target_os = "linux") {
unsafe {
INIT_GROUPS.call_once(|| {
// Guard the writes to the GROUPS with a `Once`
INIT_GROUPS.call_once(|| {
let mut groups = GROUPS.write().unwrap();
let Groups {
primary_group,
secondary_group,
gid1,
gid2,
} = &mut *groups;

// Initialize group strings
// Linux - (primary group of current user, a group in the supplemental group list that
// is not the primary group)
// macOS - ("staff", "admin")
if cfg!(target_os = "linux") {
unsafe {
let uid = libc::getuid();
let pw = libc::getpwuid(uid);
if pw.is_null() {
Expand All @@ -73,7 +91,7 @@ fn get_groups() -> ((String, u32), (String, u32)) {
}

let gr_name = CStr::from_ptr((&*gr).gr_name).to_owned();
PRIMARY_GROUP = gr_name.to_str().unwrap().to_owned();
*primary_group = gr_name.to_str().unwrap().to_owned();

let mut count = libc::getgroups(0, std::ptr::null_mut());
if count < 0 {
Expand All @@ -83,18 +101,9 @@ fn get_groups() -> ((String, u32), (String, u32)) {
);
}

let mut groups_ptr: *mut libc::gid_t =
libc::malloc(std::mem::size_of::<libc::gid_t>() * count as usize)
as *mut libc::gid_t;

if groups_ptr.is_null() {
panic!(
"unable to allocate memory for groups list: {}",
io::Error::last_os_error()
);
}
let mut groups_list: Vec<libc::gid_t> = vec![0; count as usize];

count = libc::getgroups(count, groups_ptr);
count = libc::getgroups(count, groups_list.as_mut_ptr());
match count {
_ if count < 2 => panic!("user must be a member of at least two groups"),
-1 => panic!(
Expand All @@ -104,58 +113,63 @@ fn get_groups() -> ((String, u32), (String, u32)) {
_ => {}
}

for _ in 0..count {
if groups_ptr.is_null() {
panic!("unable to get second group: reached end of group list");
}

for second_gid in groups_list {
// Skip over the primary_gid
if *groups_ptr == primary_gid {
groups_ptr = groups_ptr.offset(1);
if second_gid == primary_gid {
continue;
} else {
let second_gid = *groups_ptr;
let sec_grent = libc::getgrgid(second_gid);
if sec_grent.is_null() {
panic!("Unable to get group entry for secondary group id {second_gid}");
}

let sec_gr_name = CStr::from_ptr((&*sec_grent).gr_name).to_owned();
SECONDARY_GROUP = sec_gr_name.to_str().unwrap().to_owned();
*secondary_group = sec_gr_name.to_str().unwrap().to_owned();
break;
}
}
if SECONDARY_GROUP == "" {

if secondary_group.is_empty() {
panic!("unable to find suitable secondary group");
}
});

(PRIMARY_GROUP.clone(), SECONDARY_GROUP.clone())
}
} else if cfg!(target_os = "macos") {
*primary_group = String::from("staff");
*secondary_group = String::from("admin");
} else {
panic!("Unsupported OS");
}
} else if cfg!(target_os = "macos") {
("staff".to_owned(), "admin".to_owned())
} else {
panic!("Unsupported OS")
};

unsafe {
INIT_GID.call_once(|| {
let g1_cstr = CString::new(g1.as_str()).unwrap();
let g2_cstr = CString::new(g2.as_str()).unwrap();

GID1 = get_group_id(&g1_cstr);
GID2 = get_group_id(&g2_cstr);
});

// Must be initialized
assert_ne!(GID1, 0);
assert_ne!(GID2, 0);
// Initialize the group IDs corresponding to the group strings
{
let g1_cstr = CString::new(primary_group.as_str()).unwrap();
let g2_cstr = CString::new(secondary_group.as_str()).unwrap();

// Must be different groups
assert_ne!(GID1, GID2);
*gid1 = get_group_id(&g1_cstr);
*gid2 = get_group_id(&g2_cstr);
}
});

((g1, GID1), (g2, GID2))
}
// The reads to GROUPS should not have conflicts with the writes because:
// 1) `Once` will block until the initialization is finished.
// 2) The writes are only done inside `Once::call_once`.
let groups = GROUPS.read().unwrap();
let groups_ref = &*groups;
let Groups {
primary_group,
secondary_group,
gid1,
gid2,
} = groups_ref.clone();

// Must be initialized
assert_ne!(gid1, 0);
assert_ne!(gid2, 0);

// Must be different groups
assert_ne!(gid1, gid2);

((primary_group, gid1), (secondary_group, gid2))
}

fn file_gid(path: &str) -> io::Result<u32> {
Expand Down