-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathlib.rs
422 lines (377 loc) · 13 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
pub mod lighthouse;
pub mod manager;
mod net;
mod retry;
mod timeout;
use core::time::Duration;
use std::env;
use std::sync::Arc;
use anyhow::Result;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tonic::transport::Channel;
use tonic::Status;
/// Generated protobuf/gRPC code for the `torchft` proto package, included
/// from the build-time output via `tonic::include_proto!`. Provides, among
/// other items, `manager_service_client::ManagerServiceClient` and the
/// request message types imported right after this module.
pub mod torchftpb {
tonic::include_proto!("torchft");
}
use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
use pyo3::prelude::*;
/// ManagerServer is a GRPC server for the manager service.
/// There should be one manager server per replica group (typically running on
/// the rank 0 host). The individual ranks within a replica group should use
/// ManagerClient to communicate with the manager server and participate in
/// quorum operations.
///
/// The constructor blocks, with the GIL released, until the manager has been
/// created, and then serves it on a background task.
///
/// Args:
///     replica_id (str): The ID of the replica group.
///     lighthouse_addr (str): The HTTP address of the lighthouse server.
///     hostname (str): The hostname of the manager server.
///     bind (str): The HTTP address to bind the server to.
///     store_addr (str): The HTTP address of the store server.
///     world_size (int): The world size of the replica group.
///     heartbeat_interval (timedelta): The interval at which heartbeats are sent.
///     connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
///
/// Raises:
///     RuntimeError: If the manager fails to start.
#[pyclass]
struct ManagerServer {
// Background task running `Manager::run`; spawned in `new`, aborted by `shutdown`.
handle: JoinHandle<Result<()>>,
// Shared manager instance; `address()` reads from it.
manager: Arc<manager::Manager>,
// Tokio runtime that drives `handle`. It is never read; it is held only so it
// stays alive as long as the server object. Rust drops fields in declaration
// order, so this is dropped after `handle` and `manager`.
_runtime: Runtime,
}
#[pymethods]
impl ManagerServer {
#[new]
fn new(
py: Python<'_>,
replica_id: String,
lighthouse_addr: String,
hostname: String,
bind: String,
store_addr: String,
world_size: u64,
heartbeat_interval: Duration,
connect_timeout: Duration,
) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = Runtime::new()?;
let manager = runtime
.block_on(manager::Manager::new(
replica_id,
lighthouse_addr,
hostname,
bind,
store_addr,
world_size,
heartbeat_interval,
connect_timeout,
))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let handle = runtime.spawn(manager.clone().run());
Ok(Self {
handle: handle,
manager: manager,
_runtime: runtime,
})
})
}
/// address returns the address of the manager server.
///
/// Returns:
/// str: The address of the manager server.
fn address(&self) -> PyResult<String> {
Ok(self.manager.address().to_string())
}
/// shutdown shuts down the manager server.
fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
})
}
}
/// ManagerClient is a GRPC client to the manager service.
///
/// It is used by the trainer to communicate with the ManagerServer.
///
/// Args:
///     addr (str): The HTTP address of the manager server.
///     connect_timeout (timedelta): The timeout for connecting to the manager server.
///
/// Raises:
///     RuntimeError: If the client cannot be created.
#[pyclass]
struct ManagerClient {
// Runtime owned by this client; each RPC is driven to completion on it with
// `block_on`.
runtime: Runtime,
// Connected gRPC client; each RPC method works on a clone of it.
client: ManagerServiceClient<Channel>,
}
#[pymethods]
impl ManagerClient {
    #[new]
    fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
        // Connect synchronously on a dedicated runtime, without holding the GIL.
        py.allow_threads(move || {
            let runtime = Runtime::new()?;
            let connect = manager::manager_client_new(addr, connect_timeout);
            let client = runtime
                .block_on(connect)
                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
            Ok(Self { runtime, client })
        })
    }

    /// _quorum sends a quorum request for this rank to the manager and
    /// returns the manager's response.
    ///
    /// Args:
    ///     rank (int): The rank of the trainer.
    ///     step (int): The step of the trainer.
    ///     checkpoint_metadata (str): Checkpoint metadata forwarded in the request.
    ///     shrink_only (bool): Forwarded verbatim in the request.
    ///     timeout (timedelta): Deadline attached to the request.
    ///
    /// Returns:
    ///     QuorumResult: The fields of the manager's quorum response.
    ///
    /// Raises:
    ///     TimeoutError: If the request is cancelled or its deadline is exceeded.
    ///     RuntimeError: For any other gRPC failure.
    fn _quorum(
        &self,
        py: Python<'_>,
        rank: i64,
        step: i64,
        checkpoint_metadata: String,
        shrink_only: bool,
        timeout: Duration,
    ) -> Result<QuorumResult, StatusError> {
        py.allow_threads(move || {
            let mut request = tonic::Request::new(ManagerQuorumRequest {
                rank,
                step,
                checkpoint_metadata,
                shrink_only,
            });
            // The deadline is forwarded to and enforced by the server; server
            // health is detected separately through keep alives.
            request.set_timeout(timeout);

            let mut client = self.client.clone();
            let reply = self.runtime.block_on(client.quorum(request))?.into_inner();

            Ok(QuorumResult {
                quorum_id: reply.quorum_id,
                replica_rank: reply.replica_rank,
                replica_world_size: reply.replica_world_size,
                recover_src_manager_address: reply.recover_src_manager_address,
                recover_src_rank: reply.recover_src_rank,
                recover_dst_ranks: reply.recover_dst_ranks,
                store_address: reply.store_address,
                max_step: reply.max_step,
                max_rank: reply.max_rank,
                max_world_size: reply.max_world_size,
                heal: reply.heal,
            })
        })
    }

    /// _checkpoint_metadata asks the manager for the checkpoint metadata
    /// associated with ``rank``.
    ///
    /// Args:
    ///     rank (int): The rank whose checkpoint metadata is requested.
    ///     timeout (timedelta): Deadline attached to the request.
    ///
    /// Returns:
    ///     str: The ``checkpoint_metadata`` field of the manager's reply.
    ///
    /// Raises:
    ///     TimeoutError: If the request is cancelled or its deadline is exceeded.
    ///     RuntimeError: For any other gRPC failure.
    fn _checkpoint_metadata(
        &self,
        py: Python<'_>,
        rank: i64,
        timeout: Duration,
    ) -> Result<String, StatusError> {
        py.allow_threads(move || {
            let mut request = tonic::Request::new(CheckpointMetadataRequest { rank });
            // Server-side deadline; liveness is covered by keep alives.
            request.set_timeout(timeout);

            let mut client = self.client.clone();
            let reply = self
                .runtime
                .block_on(client.checkpoint_metadata(request))?
                .into_inner();
            Ok(reply.checkpoint_metadata)
        })
    }

    /// should_commit asks the manager whether the trainer may commit the
    /// current step. The call waits until every rank has checked in at
    /// ``step``, and the answer is false if any worker passed
    /// ``should_commit=False``.
    ///
    /// Args:
    ///     rank (int): The rank of the trainer.
    ///     step (int): The step of the trainer.
    ///     should_commit (bool): This trainer's vote on committing the step.
    ///     timeout (timedelta): The timeout for the request. A TimeoutError
    ///         is raised if it expires.
    ///
    /// Returns:
    ///     bool: Whether the trainer should commit the current step.
    fn should_commit(
        &self,
        py: Python<'_>,
        rank: i64,
        step: i64,
        should_commit: bool,
        timeout: Duration,
    ) -> Result<bool, StatusError> {
        py.allow_threads(move || {
            let mut request = tonic::Request::new(ShouldCommitRequest {
                rank,
                step,
                should_commit,
            });
            // Tells the server about the deadline; the endpoint timeout chosen
            // at client creation still applies independently.
            request.set_timeout(timeout);

            let mut client = self.client.clone();
            let verdict = self
                .runtime
                .block_on(client.should_commit(request))?
                .into_inner();
            Ok(verdict.should_commit)
        })
    }
}
/// Result of a quorum request, as returned by ``ManagerClient._quorum``.
///
/// Every field is copied from the manager's quorum response and is exposed to
/// Python as a readable and writable attribute. ``QuorumResult()`` with no
/// arguments builds a placeholder whose fields are zero, empty, ``None`` or
/// ``False``, except ``replica_world_size`` and ``max_world_size``, which are 1.
#[pyclass(get_all, set_all)]
struct QuorumResult {
// NOTE(review): the meaning of each field is defined by the manager's quorum
// logic, which is not visible in this file.
quorum_id: i64,
replica_rank: i64,
replica_world_size: i64,
recover_src_manager_address: String,
// Option field: an int or None on the Python side.
recover_src_rank: Option<i64>,
recover_dst_ranks: Vec<i64>,
store_address: String,
max_step: i64,
// Option field: an int or None on the Python side.
max_rank: Option<i64>,
max_world_size: i64,
heal: bool,
}
#[pymethods]
impl QuorumResult {
    #[new]
    fn new() -> Self {
        // Neutral placeholder: counters at zero, both world sizes at one,
        // optional ranks unset, strings and rank lists empty, no healing.
        Self {
            heal: false,
            quorum_id: 0,
            replica_rank: 0,
            replica_world_size: 1,
            max_step: 0,
            max_rank: None,
            max_world_size: 1,
            recover_src_manager_address: String::new(),
            recover_src_rank: None,
            recover_dst_ranks: vec![],
            store_address: String::new(),
        }
    }
}
/// Restores the default OS handling of SIGINT. This is the Rust equivalent of
/// `signal.signal(signal.SIGINT, signal.SIG_DFL)`. Only SIGINT is touched.
///
/// # Errors
/// Propagates any Python exception raised while importing `signal`, looking up
/// its attributes, or calling `signal.signal`. Python only allows that last
/// call from the main thread.
fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
    let signal_mod = py.import_bound("signal")?;
    let install_handler = signal_mod.getattr("signal")?;
    let (sigint, sig_dfl) = (signal_mod.getattr("SIGINT")?, signal_mod.getattr("SIG_DFL")?);
    install_handler.call1((sigint, sig_dfl)).map(|_| ())
}
/// Runs a lighthouse server configured from the process command line.
///
/// Presumably the entry point behind the ``torchft_lighthouse`` command. It
/// first restores the default SIGINT handler, so Ctrl-C terminates the
/// process. It then parses the lighthouse options from the process arguments
/// and serves the lighthouse until it exits. Invalid arguments make structopt
/// print an error and exit the process.
///
/// The GIL is released while the server runs, matching every other blocking
/// call in this module, so other Python threads are not starved.
///
/// Raises:
///     RuntimeError: If the lighthouse fails to start or exits with an error.
#[pyfunction]
fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
    reset_python_signals(py)?;

    let mut args = env::args();
    // Discard argv[0] (the binary). NOTE(review): when launched through a
    // Python script, the script path presumably lands in the program-name slot
    // that structopt expects; confirm against the entry point.
    args.next();
    let opt = lighthouse::LighthouseOpt::from_iter(args);

    // Serving blocks until the lighthouse stops, so drop the GIL for that
    // whole period instead of holding it for the server's lifetime.
    py.allow_threads(move || {
        let rt = Runtime::new()?;
        rt.block_on(lighthouse_main_async(opt))
            .map_err(|e| PyRuntimeError::new_err(e.to_string()))
    })
}
/// Builds a lighthouse from `opt` and serves it until `run` completes.
///
/// Errors from construction or from `run` are returned unchanged.
async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
    let server = lighthouse::Lighthouse::new(opt).await?;
    server.run().await
}
/// LighthouseServer is a GRPC server for the lighthouse service.
///
/// It is used to coordinate the ManagerServer for each replica group.
///
/// This entrypoint is primarily for testing and debugging purposes. The
/// ``torchft_lighthouse`` command is recommended for most use cases.
///
/// Args:
///     bind (str): The HTTP address to bind the server to.
///     min_replicas (int): The minimum number of replicas required to form a quorum.
///     join_timeout_ms (int, optional): The timeout for joining the quorum.
///         Defaults to 100.
///     quorum_tick_ms (int, optional): The interval at which the quorum is checked.
///         Defaults to 100.
///     heartbeat_timeout_ms (int, optional): The timeout for heartbeats.
///         Defaults to 5000.
///
/// Raises:
///     RuntimeError: If the lighthouse fails to start.
#[pyclass]
struct LighthouseServer {
// Shared lighthouse instance; `address()` reads from it.
lighthouse: Arc<lighthouse::Lighthouse>,
// Background task running `Lighthouse::run`; spawned in `new`, aborted by `shutdown`.
handle: JoinHandle<Result<()>>,
// Tokio runtime that drives `handle`, held so it lives as long as the server.
// Rust drops fields in declaration order, so this is dropped last.
_runtime: Runtime,
}
#[pymethods]
impl LighthouseServer {
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
#[new]
fn new(
py: Python<'_>,
bind: String,
min_replicas: u64,
join_timeout_ms: Option<u64>,
quorum_tick_ms: Option<u64>,
heartbeat_timeout_ms: Option<u64>,
) -> PyResult<Self> {
let join_timeout_ms = join_timeout_ms.unwrap_or(100);
let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
py.allow_threads(move || {
let rt = Runtime::new()?;
let lighthouse = rt
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
bind: bind,
min_replicas: min_replicas,
join_timeout_ms: join_timeout_ms,
quorum_tick_ms: quorum_tick_ms,
heartbeat_timeout_ms: heartbeat_timeout_ms,
}))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(Self {
handle: rt.spawn(lighthouse.clone().run()),
lighthouse: lighthouse,
_runtime: rt,
})
})
}
/// address returns the address of the lighthouse server.
///
/// Returns:
/// str: The address of the lighthouse server.
fn address(&self) -> PyResult<String> {
Ok(self.lighthouse.address().to_string())
}
/// shutdown shuts down the lighthouse server.
fn shutdown(&self, py: Python<'_>) {
py.allow_threads(move || {
self.handle.abort();
})
}
}
/// Wrapper around a tonic [`Status`] produced by a manager RPC.
///
/// It exists so that `?` can turn a gRPC failure into a Python exception. The
/// orphan rule forbids implementing `From<Status> for PyErr` directly, because
/// both types are defined in other crates.
struct StatusError(Status);

/// Maps a gRPC status onto a Python exception. `Cancelled` and
/// `DeadlineExceeded` become `TimeoutError`; every other code becomes
/// `RuntimeError`. The message is the status's `Display` text.
impl From<StatusError> for PyErr {
    fn from(StatusError(status): StatusError) -> Self {
        let message = status.to_string();
        if matches!(
            status.code(),
            tonic::Code::Cancelled | tonic::Code::DeadlineExceeded
        ) {
            PyTimeoutError::new_err(message)
        } else {
            PyRuntimeError::new_err(message)
        }
    }
}

impl From<Status> for StatusError {
    fn from(status: Status) -> Self {
        StatusError(status)
    }
}
#[pymodule]
fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Importing the extension installs a process-wide stderr logger. Installing
    // it fails if another global logger is already set; that failure surfaces
    // as a RuntimeError.
    let mut logger = stderrlog::new();
    logger
        .verbosity(2)
        .show_module_names(true)
        .timestamp(stderrlog::Timestamp::Millisecond);
    // When CLICOLOR_FORCE is set (to any valid Unicode value), always emit
    // ANSI colour codes.
    if env::var("CLICOLOR_FORCE").is_ok() {
        logger.color(stderrlog::ColorChoice::AlwaysAnsi);
    }
    logger
        .init()
        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

    // Python-visible API of the extension.
    m.add_class::<ManagerServer>()?;
    m.add_class::<ManagerClient>()?;
    m.add_class::<LighthouseServer>()?;
    m.add_class::<QuorumResult>()?;
    m.add_function(wrap_pyfunction!(lighthouse_main, m)?)
}