@@ -30,15 +30,30 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
30
30
use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
31
31
use pyo3:: prelude:: * ;
32
32
33
+ /// ManagerServer is a gRPC server for the manager service.
34
+ /// There should be one manager server per replica group (typically running on
35
+ /// the rank 0 host). The individual ranks within a replica group should use
36
+ /// ManagerClient to communicate with the manager server and participate in
37
+ /// quorum operations.
38
+ ///
39
+ /// Args:
40
+ /// replica_id (str): The ID of the replica group.
41
+ /// lighthouse_addr (str): The HTTP address of the lighthouse server.
42
+ /// hostname (str): The hostname of the manager server.
43
+ /// bind (str): The HTTP address to bind the server to.
44
+ /// store_addr (str): The HTTP address of the store server.
45
+ /// world_size (int): The world size of the replica group.
46
+ /// heartbeat_interval (timedelta): The interval at which heartbeats are sent.
47
+ /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
33
48
#[ pyclass]
34
- struct Manager {
49
+ struct ManagerServer {
35
50
handle : JoinHandle < Result < ( ) > > ,
36
51
manager : Arc < manager:: Manager > ,
37
52
_runtime : Runtime ,
38
53
}
39
54
40
55
#[ pymethods]
41
- impl Manager {
56
+ impl ManagerServer {
42
57
#[ new]
43
58
fn new (
44
59
py : Python < ' _ > ,
@@ -74,17 +89,29 @@ impl Manager {
74
89
} )
75
90
}
76
91
92
+ /// address returns the address of the manager server.
93
+ ///
94
+ /// Returns:
95
+ /// str: The address of the manager server.
77
96
fn address ( & self ) -> PyResult < String > {
78
97
Ok ( self . manager . address ( ) . to_string ( ) )
79
98
}
80
99
100
+ /// shutdown shuts down the manager server.
81
101
fn shutdown ( & self , py : Python < ' _ > ) {
82
102
py. allow_threads ( move || {
83
103
self . handle . abort ( ) ;
84
104
} )
85
105
}
86
106
}
87
107
108
+ /// ManagerClient is a gRPC client to the manager service.
109
+ ///
110
+ /// It is used by the trainer to communicate with the ManagerServer.
111
+ ///
112
+ /// Args:
113
+ /// addr (str): The HTTP address of the manager server.
114
+ /// connect_timeout (timedelta): The timeout for connecting to the manager server.
88
115
#[ pyclass]
89
116
struct ManagerClient {
90
117
runtime : Runtime ,
@@ -108,7 +135,7 @@ impl ManagerClient {
108
135
} )
109
136
}
110
137
111
- fn quorum (
138
+ fn _quorum (
112
139
& self ,
113
140
py : Python < ' _ > ,
114
141
rank : i64 ,
@@ -147,7 +174,7 @@ impl ManagerClient {
147
174
} )
148
175
}
149
176
150
- fn checkpoint_metadata (
177
+ fn _checkpoint_metadata (
151
178
& self ,
152
179
py : Python < ' _ > ,
153
180
rank : i64 ,
@@ -168,6 +195,20 @@ impl ManagerClient {
168
195
} )
169
196
}
170
197
198
+ /// should_commit makes a request to the manager to determine if the trainer
199
+ /// should commit the current step. This waits until all ranks check in at
200
+ /// the specified step and will return false if any worker passes
201
+ /// ``should_commit=False``.
202
+ ///
203
+ /// Args:
204
+ /// rank (int): The rank of the trainer.
205
+ /// step (int): The step of the trainer.
206
+ /// should_commit (bool): Whether the trainer should commit the current step.
207
+ /// timeout (timedelta): The timeout for the request. If the request
208
+ /// times out a TimeoutError is raised.
209
+ ///
210
+ /// Returns:
211
+ /// bool: Whether the trainer should commit the current step.
171
212
fn should_commit (
172
213
& self ,
173
214
py : Python < ' _ > ,
@@ -263,15 +304,28 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
263
304
Ok ( ( ) )
264
305
}
265
306
307
+ /// LighthouseServer is a gRPC server for the lighthouse service.
308
+ ///
309
+ /// It is used to coordinate the ManagerServer for each replica group.
310
+ ///
311
+ /// This entrypoint is primarily for testing and debugging purposes. The
312
+ /// ``torchft_lighthouse`` command is recommended for most use cases.
313
+ ///
314
+ /// Args:
315
+ /// bind (str): The HTTP address to bind the server to.
316
+ /// min_replicas (int): The minimum number of replicas required to form a quorum.
317
+ /// join_timeout_ms (int): The timeout for joining the quorum, in milliseconds.
318
+ /// quorum_tick_ms (int): The interval at which the quorum is checked, in milliseconds.
319
+ /// heartbeat_timeout_ms (int): The timeout for heartbeats, in milliseconds.
266
320
#[ pyclass]
267
- struct Lighthouse {
321
+ struct LighthouseServer {
268
322
lighthouse : Arc < lighthouse:: Lighthouse > ,
269
323
handle : JoinHandle < Result < ( ) > > ,
270
324
_runtime : Runtime ,
271
325
}
272
326
273
327
#[ pymethods]
274
- impl Lighthouse {
328
+ impl LighthouseServer {
275
329
#[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None ) ) ]
276
330
#[ new]
277
331
fn new (
@@ -307,10 +361,15 @@ impl Lighthouse {
307
361
} )
308
362
}
309
363
364
+ /// address returns the address of the lighthouse server.
365
+ ///
366
+ /// Returns:
367
+ /// str: The address of the lighthouse server.
310
368
fn address ( & self ) -> PyResult < String > {
311
369
Ok ( self . lighthouse . address ( ) . to_string ( ) )
312
370
}
313
371
372
+ /// shutdown shuts down the lighthouse server.
314
373
fn shutdown ( & self , py : Python < ' _ > ) {
315
374
py. allow_threads ( move || {
316
375
self . handle . abort ( ) ;
@@ -339,7 +398,7 @@ impl From<Status> for StatusError {
339
398
}
340
399
341
400
#[ pymodule]
342
- fn torchft ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
401
+ fn _torchft ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
343
402
// setup logging on import
344
403
let mut log = stderrlog:: new ( ) ;
345
404
log. verbosity ( 2 )
@@ -353,9 +412,9 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
353
412
log. init ( )
354
413
. map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
355
414
356
- m. add_class :: < Manager > ( ) ?;
415
+ m. add_class :: < ManagerServer > ( ) ?;
357
416
m. add_class :: < ManagerClient > ( ) ?;
358
- m. add_class :: < Lighthouse > ( ) ?;
417
+ m. add_class :: < LighthouseServer > ( ) ?;
359
418
m. add_class :: < QuorumResult > ( ) ?;
360
419
m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
361
420
0 commit comments