@@ -15,7 +15,7 @@ use std::{
15
15
self ,
16
16
fs:: { MetadataExt , PermissionsExt } ,
17
17
} ,
18
- sync:: Once ,
18
+ sync:: { Once , RwLock } ,
19
19
thread,
20
20
time:: Duration ,
21
21
} ;
@@ -33,13 +33,21 @@ fn chgrp_test(args: &[&str], expected_output: &str, expected_error: &str, expect
33
33
} ) ;
34
34
}
35
35
36
- static INIT_GROUPS : Once = Once :: new ( ) ;
37
- static mut PRIMARY_GROUP : String = String :: new ( ) ;
38
- static mut SECONDARY_GROUP : String = String :: new ( ) ;
36
+ #[ derive( Clone ) ]
37
+ struct Groups {
38
+ primary_group : String ,
39
+ secondary_group : String ,
40
+ gid1 : u32 ,
41
+ gid2 : u32 ,
42
+ }
39
43
40
- static mut GID1 : u32 = 0 ;
41
- static mut GID2 : u32 = 0 ;
42
- static INIT_GID : Once = Once :: new ( ) ;
44
+ static INIT_GROUPS : Once = Once :: new ( ) ;
45
+ static GROUPS : RwLock < Groups > = RwLock :: new ( Groups {
46
+ primary_group : String :: new ( ) ,
47
+ secondary_group : String :: new ( ) ,
48
+ gid1 : 0 ,
49
+ gid2 : 0 ,
50
+ } ) ;
43
51
44
52
fn get_group_id ( name : & CStr ) -> u32 {
45
53
unsafe {
@@ -54,12 +62,22 @@ fn get_group_id(name: &CStr) -> u32 {
54
62
55
63
// Return two groups that the current user belongs to.
56
64
fn get_groups ( ) -> ( ( String , u32 ) , ( String , u32 ) ) {
57
- // Linux - (primary group of current user, a group in the supplemental group list that
58
- // is not the primary group)
59
- // macOS - ("staff", "admin")
60
- let ( g1, g2) = if cfg ! ( target_os = "linux" ) {
61
- unsafe {
62
- INIT_GROUPS . call_once ( || {
65
+ // Guard the writes to the GROUPS with a `Once`
66
+ INIT_GROUPS . call_once ( || {
67
+ let mut groups = GROUPS . write ( ) . unwrap ( ) ;
68
+ let Groups {
69
+ primary_group,
70
+ secondary_group,
71
+ gid1,
72
+ gid2,
73
+ } = & mut * groups;
74
+
75
+ // Initialize group strings
76
+ // Linux - (primary group of current user, a group in the supplemental group list that
77
+ // is not the primary group)
78
+ // macOS - ("staff", "admin")
79
+ if cfg ! ( target_os = "linux" ) {
80
+ unsafe {
63
81
let uid = libc:: getuid ( ) ;
64
82
let pw = libc:: getpwuid ( uid) ;
65
83
if pw. is_null ( ) {
@@ -73,7 +91,7 @@ fn get_groups() -> ((String, u32), (String, u32)) {
73
91
}
74
92
75
93
let gr_name = CStr :: from_ptr ( ( & * gr) . gr_name ) . to_owned ( ) ;
76
- PRIMARY_GROUP = gr_name. to_str ( ) . unwrap ( ) . to_owned ( ) ;
94
+ * primary_group = gr_name. to_str ( ) . unwrap ( ) . to_owned ( ) ;
77
95
78
96
let mut count = libc:: getgroups ( 0 , std:: ptr:: null_mut ( ) ) ;
79
97
if count < 0 {
@@ -83,18 +101,9 @@ fn get_groups() -> ((String, u32), (String, u32)) {
83
101
) ;
84
102
}
85
103
86
- let mut groups_ptr: * mut libc:: gid_t =
87
- libc:: malloc ( std:: mem:: size_of :: < libc:: gid_t > ( ) * count as usize )
88
- as * mut libc:: gid_t ;
89
-
90
- if groups_ptr. is_null ( ) {
91
- panic ! (
92
- "unable to allocate memory for groups list: {}" ,
93
- io:: Error :: last_os_error( )
94
- ) ;
95
- }
104
+ let mut groups_list: Vec < libc:: gid_t > = vec ! [ 0 ; count as usize ] ;
96
105
97
- count = libc:: getgroups ( count, groups_ptr ) ;
106
+ count = libc:: getgroups ( count, groups_list . as_mut_ptr ( ) ) ;
98
107
match count {
99
108
_ if count < 2 => panic ! ( "user must be a member of at least two groups" ) ,
100
109
-1 => panic ! (
@@ -104,58 +113,63 @@ fn get_groups() -> ((String, u32), (String, u32)) {
104
113
_ => { }
105
114
}
106
115
107
- for _ in 0 ..count {
108
- if groups_ptr. is_null ( ) {
109
- panic ! ( "unable to get second group: reached end of group list" ) ;
110
- }
111
-
116
+ for second_gid in groups_list {
112
117
// Skip over the primary_gid
113
- if * groups_ptr == primary_gid {
114
- groups_ptr = groups_ptr. offset ( 1 ) ;
118
+ if second_gid == primary_gid {
115
119
continue ;
116
120
} else {
117
- let second_gid = * groups_ptr;
118
121
let sec_grent = libc:: getgrgid ( second_gid) ;
119
122
if sec_grent. is_null ( ) {
120
123
panic ! ( "Unable to get group entry for secondary group id {second_gid}" ) ;
121
124
}
122
125
123
126
let sec_gr_name = CStr :: from_ptr ( ( & * sec_grent) . gr_name ) . to_owned ( ) ;
124
- SECONDARY_GROUP = sec_gr_name. to_str ( ) . unwrap ( ) . to_owned ( ) ;
127
+ * secondary_group = sec_gr_name. to_str ( ) . unwrap ( ) . to_owned ( ) ;
125
128
break ;
126
129
}
127
130
}
128
- if SECONDARY_GROUP == "" {
131
+
132
+ if secondary_group. is_empty ( ) {
129
133
panic ! ( "unable to find suitable secondary group" ) ;
130
134
}
131
- } ) ;
132
-
133
- ( PRIMARY_GROUP . clone ( ) , SECONDARY_GROUP . clone ( ) )
135
+ }
136
+ } else if cfg ! ( target_os = "macos" ) {
137
+ * primary_group = String :: from ( "staff" ) ;
138
+ * secondary_group = String :: from ( "admin" ) ;
139
+ } else {
140
+ panic ! ( "Unsupported OS" ) ;
134
141
}
135
- } else if cfg ! ( target_os = "macos" ) {
136
- ( "staff" . to_owned ( ) , "admin" . to_owned ( ) )
137
- } else {
138
- panic ! ( "Unsupported OS" )
139
- } ;
140
-
141
- unsafe {
142
- INIT_GID . call_once ( || {
143
- let g1_cstr = CString :: new ( g1. as_str ( ) ) . unwrap ( ) ;
144
- let g2_cstr = CString :: new ( g2. as_str ( ) ) . unwrap ( ) ;
145
-
146
- GID1 = get_group_id ( & g1_cstr) ;
147
- GID2 = get_group_id ( & g2_cstr) ;
148
- } ) ;
149
142
150
- // Must be initialized
151
- assert_ne ! ( GID1 , 0 ) ;
152
- assert_ne ! ( GID2 , 0 ) ;
143
+ // Initialize the group IDs corresponding to the group strings
144
+ {
145
+ let g1_cstr = CString :: new ( primary_group. as_str ( ) ) . unwrap ( ) ;
146
+ let g2_cstr = CString :: new ( secondary_group. as_str ( ) ) . unwrap ( ) ;
153
147
154
- // Must be different groups
155
- assert_ne ! ( GID1 , GID2 ) ;
148
+ * gid1 = get_group_id ( & g1_cstr) ;
149
+ * gid2 = get_group_id ( & g2_cstr) ;
150
+ }
151
+ } ) ;
156
152
157
- ( ( g1, GID1 ) , ( g2, GID2 ) )
158
- }
153
+ // The reads to GROUPS should not have conflicts with the writes because:
154
+ // 1) `Once` will block until the initialization is finished.
155
+ // 2) The writes are only done inside `Once::call_once`.
156
+ let groups = GROUPS . read ( ) . unwrap ( ) ;
157
+ let groups_ref = & * groups;
158
+ let Groups {
159
+ primary_group,
160
+ secondary_group,
161
+ gid1,
162
+ gid2,
163
+ } = groups_ref. clone ( ) ;
164
+
165
+ // Must be initialized
166
+ assert_ne ! ( gid1, 0 ) ;
167
+ assert_ne ! ( gid2, 0 ) ;
168
+
169
+ // Must be different groups
170
+ assert_ne ! ( gid1, gid2) ;
171
+
172
+ ( ( primary_group, gid1) , ( secondary_group, gid2) )
159
173
}
160
174
161
175
fn file_gid ( path : & str ) -> io:: Result < u32 > {
0 commit comments