Skip to content

Commit 0c321b0

Browse files
authored
Merge pull request #386 from JRF63/fix-chgrp-tests
Remove `static mut` in chgrp tests and fix memleak inside `get_groups`
2 parents c7d9602 + 0ae6b2d commit 0c321b0

File tree

1 file changed

+73
-59
lines changed

1 file changed

+73
-59
lines changed

tree/tests/chgrp/mod.rs

+73-59
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use std::{
1515
self,
1616
fs::{MetadataExt, PermissionsExt},
1717
},
18-
sync::Once,
18+
sync::{Once, RwLock},
1919
thread,
2020
time::Duration,
2121
};
@@ -33,13 +33,21 @@ fn chgrp_test(args: &[&str], expected_output: &str, expected_error: &str, expect
3333
});
3434
}
3535

36-
static INIT_GROUPS: Once = Once::new();
37-
static mut PRIMARY_GROUP: String = String::new();
38-
static mut SECONDARY_GROUP: String = String::new();
36+
#[derive(Clone)]
37+
struct Groups {
38+
primary_group: String,
39+
secondary_group: String,
40+
gid1: u32,
41+
gid2: u32,
42+
}
3943

40-
static mut GID1: u32 = 0;
41-
static mut GID2: u32 = 0;
42-
static INIT_GID: Once = Once::new();
44+
static INIT_GROUPS: Once = Once::new();
45+
static GROUPS: RwLock<Groups> = RwLock::new(Groups {
46+
primary_group: String::new(),
47+
secondary_group: String::new(),
48+
gid1: 0,
49+
gid2: 0,
50+
});
4351

4452
fn get_group_id(name: &CStr) -> u32 {
4553
unsafe {
@@ -54,12 +62,22 @@ fn get_group_id(name: &CStr) -> u32 {
5462

5563
// Return two groups that the current user belongs to.
5664
fn get_groups() -> ((String, u32), (String, u32)) {
57-
// Linux - (primary group of current user, a group in the supplemental group list that
58-
// is not the primary group)
59-
// macOS - ("staff", "admin")
60-
let (g1, g2) = if cfg!(target_os = "linux") {
61-
unsafe {
62-
INIT_GROUPS.call_once(|| {
65+
// Guard the writes to the GROUPS with a `Once`
66+
INIT_GROUPS.call_once(|| {
67+
let mut groups = GROUPS.write().unwrap();
68+
let Groups {
69+
primary_group,
70+
secondary_group,
71+
gid1,
72+
gid2,
73+
} = &mut *groups;
74+
75+
// Initialize group strings
76+
// Linux - (primary group of current user, a group in the supplemental group list that
77+
// is not the primary group)
78+
// macOS - ("staff", "admin")
79+
if cfg!(target_os = "linux") {
80+
unsafe {
6381
let uid = libc::getuid();
6482
let pw = libc::getpwuid(uid);
6583
if pw.is_null() {
@@ -73,7 +91,7 @@ fn get_groups() -> ((String, u32), (String, u32)) {
7391
}
7492

7593
let gr_name = CStr::from_ptr((&*gr).gr_name).to_owned();
76-
PRIMARY_GROUP = gr_name.to_str().unwrap().to_owned();
94+
*primary_group = gr_name.to_str().unwrap().to_owned();
7795

7896
let mut count = libc::getgroups(0, std::ptr::null_mut());
7997
if count < 0 {
@@ -83,18 +101,9 @@ fn get_groups() -> ((String, u32), (String, u32)) {
83101
);
84102
}
85103

86-
let mut groups_ptr: *mut libc::gid_t =
87-
libc::malloc(std::mem::size_of::<libc::gid_t>() * count as usize)
88-
as *mut libc::gid_t;
89-
90-
if groups_ptr.is_null() {
91-
panic!(
92-
"unable to allocate memory for groups list: {}",
93-
io::Error::last_os_error()
94-
);
95-
}
104+
let mut groups_list: Vec<libc::gid_t> = vec![0; count as usize];
96105

97-
count = libc::getgroups(count, groups_ptr);
106+
count = libc::getgroups(count, groups_list.as_mut_ptr());
98107
match count {
99108
_ if count < 2 => panic!("user must be a member of at least two groups"),
100109
-1 => panic!(
@@ -104,58 +113,63 @@ fn get_groups() -> ((String, u32), (String, u32)) {
104113
_ => {}
105114
}
106115

107-
for _ in 0..count {
108-
if groups_ptr.is_null() {
109-
panic!("unable to get second group: reached end of group list");
110-
}
111-
116+
for second_gid in groups_list {
112117
// Skip over the primary_gid
113-
if *groups_ptr == primary_gid {
114-
groups_ptr = groups_ptr.offset(1);
118+
if second_gid == primary_gid {
115119
continue;
116120
} else {
117-
let second_gid = *groups_ptr;
118121
let sec_grent = libc::getgrgid(second_gid);
119122
if sec_grent.is_null() {
120123
panic!("Unable to get group entry for secondary group id {second_gid}");
121124
}
122125

123126
let sec_gr_name = CStr::from_ptr((&*sec_grent).gr_name).to_owned();
124-
SECONDARY_GROUP = sec_gr_name.to_str().unwrap().to_owned();
127+
*secondary_group = sec_gr_name.to_str().unwrap().to_owned();
125128
break;
126129
}
127130
}
128-
if SECONDARY_GROUP == "" {
131+
132+
if secondary_group.is_empty() {
129133
panic!("unable to find suitable secondary group");
130134
}
131-
});
132-
133-
(PRIMARY_GROUP.clone(), SECONDARY_GROUP.clone())
135+
}
136+
} else if cfg!(target_os = "macos") {
137+
*primary_group = String::from("staff");
138+
*secondary_group = String::from("admin");
139+
} else {
140+
panic!("Unsupported OS");
134141
}
135-
} else if cfg!(target_os = "macos") {
136-
("staff".to_owned(), "admin".to_owned())
137-
} else {
138-
panic!("Unsupported OS")
139-
};
140-
141-
unsafe {
142-
INIT_GID.call_once(|| {
143-
let g1_cstr = CString::new(g1.as_str()).unwrap();
144-
let g2_cstr = CString::new(g2.as_str()).unwrap();
145-
146-
GID1 = get_group_id(&g1_cstr);
147-
GID2 = get_group_id(&g2_cstr);
148-
});
149142

150-
// Must be initialized
151-
assert_ne!(GID1, 0);
152-
assert_ne!(GID2, 0);
143+
// Initialize the group IDs corresponding to the group strings
144+
{
145+
let g1_cstr = CString::new(primary_group.as_str()).unwrap();
146+
let g2_cstr = CString::new(secondary_group.as_str()).unwrap();
153147

154-
// Must be different groups
155-
assert_ne!(GID1, GID2);
148+
*gid1 = get_group_id(&g1_cstr);
149+
*gid2 = get_group_id(&g2_cstr);
150+
}
151+
});
156152

157-
((g1, GID1), (g2, GID2))
158-
}
153+
// The reads to GROUPS should not have conflicts with the writes because:
154+
// 1) `Once` will block until the initialization is finished.
155+
// 2) The writes are only done inside `Once::call_once`.
156+
let groups = GROUPS.read().unwrap();
157+
let groups_ref = &*groups;
158+
let Groups {
159+
primary_group,
160+
secondary_group,
161+
gid1,
162+
gid2,
163+
} = groups_ref.clone();
164+
165+
// Must be initialized
166+
assert_ne!(gid1, 0);
167+
assert_ne!(gid2, 0);
168+
169+
// Must be different groups
170+
assert_ne!(gid1, gid2);
171+
172+
((primary_group, gid1), (secondary_group, gid2))
159173
}
160174

161175
fn file_gid(path: &str) -> io::Result<u32> {

0 commit comments

Comments
 (0)