Commit bf27956

coordination: expose new low level torchft coordination API (#84)
1 parent 8628a3f commit bf27956

12 files changed: +151 −44 lines changed

docs/source/coordination.rst (+4)

@@ -0,0 +1,4 @@
+.. automodule:: torchft.coordination
+   :members:
+   :undoc-members:
+   :show-inheritance:

docs/source/index.rst (+1)

@@ -21,6 +21,7 @@ the entire training job.
    data
    checkpointing
    parameter_server
+   coordination


 License

pyproject.toml (+1)

@@ -33,6 +33,7 @@ dev = [

 [tool.maturin]
 features = ["pyo3/extension-module"]
+module-name = "torchft._torchft"

 [project.scripts]
 torchft_lighthouse = "torchft.torchft:lighthouse_main"

src/lib.rs (+68 −9)

@@ -30,15 +30,30 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
 use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
 use pyo3::prelude::*;

+/// ManagerServer is a GRPC server for the manager service.
+/// There should be one manager server per replica group (typically running on
+/// the rank 0 host). The individual ranks within a replica group should use
+/// ManagerClient to communicate with the manager server and participate in
+/// quorum operations.
+///
+/// Args:
+///     replica_id (str): The ID of the replica group.
+///     lighthouse_addr (str): The HTTP address of the lighthouse server.
+///     hostname (str): The hostname of the manager server.
+///     bind (str): The HTTP address to bind the server to.
+///     store_addr (str): The HTTP address of the store server.
+///     world_size (int): The world size of the replica group.
+///     heartbeat_interval (timedelta): The interval at which heartbeats are sent.
+///     connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
 #[pyclass]
-struct Manager {
+struct ManagerServer {
     handle: JoinHandle<Result<()>>,
     manager: Arc<manager::Manager>,
     _runtime: Runtime,
 }

 #[pymethods]
-impl Manager {
+impl ManagerServer {
     #[new]
     fn new(
         py: Python<'_>,
@@ -74,17 +89,29 @@ impl Manager {
         })
     }

+    /// address returns the address of the manager server.
+    ///
+    /// Returns:
+    ///     str: The address of the manager server.
     fn address(&self) -> PyResult<String> {
         Ok(self.manager.address().to_string())
     }

+    /// shutdown shuts down the manager server.
     fn shutdown(&self, py: Python<'_>) {
         py.allow_threads(move || {
             self.handle.abort();
         })
     }
 }

+/// ManagerClient is a GRPC client to the manager service.
+///
+/// It is used by the trainer to communicate with the ManagerServer.
+///
+/// Args:
+///     addr (str): The HTTP address of the manager server.
+///     connect_timeout (timedelta): The timeout for connecting to the manager server.
 #[pyclass]
 struct ManagerClient {
     runtime: Runtime,
@@ -108,7 +135,7 @@ impl ManagerClient {
         })
     }

-    fn quorum(
+    fn _quorum(
         &self,
         py: Python<'_>,
         rank: i64,
@@ -147,7 +174,7 @@ impl ManagerClient {
         })
     }

-    fn checkpoint_metadata(
+    fn _checkpoint_metadata(
         &self,
         py: Python<'_>,
         rank: i64,
@@ -168,6 +195,20 @@ impl ManagerClient {
         })
     }

+    /// should_commit makes a request to the manager to determine if the trainer
+    /// should commit the current step. This waits until all ranks check in at
+    /// the specified step and will return false if any worker passes
+    /// ``should_commit=False``.
+    ///
+    /// Args:
+    ///     rank (int): The rank of the trainer.
+    ///     step (int): The step of the trainer.
+    ///     should_commit (bool): Whether the trainer should commit the current step.
+    ///     timeout (timedelta): The timeout for the request. If the request
+    ///         times out a TimeoutError is raised.
+    ///
+    /// Returns:
+    ///     bool: Whether the trainer should commit the current step.
     fn should_commit(
         &self,
         py: Python<'_>,
@@ -263,15 +304,28 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
     Ok(())
 }

+/// LighthouseServer is a GRPC server for the lighthouse service.
+///
+/// It is used to coordinate the ManagerServer for each replica group.
+///
+/// This entrypoint is primarily for testing and debugging purposes. The
+/// ``torchft_lighthouse`` command is recommended for most use cases.
+///
+/// Args:
+///     bind (str): The HTTP address to bind the server to.
+///     min_replicas (int): The minimum number of replicas required to form a quorum.
+///     join_timeout_ms (int): The timeout for joining the quorum.
+///     quorum_tick_ms (int): The interval at which the quorum is checked.
+///     heartbeat_timeout_ms (int): The timeout for heartbeats.
 #[pyclass]
-struct Lighthouse {
+struct LighthouseServer {
     lighthouse: Arc<lighthouse::Lighthouse>,
     handle: JoinHandle<Result<()>>,
     _runtime: Runtime,
 }

 #[pymethods]
-impl Lighthouse {
+impl LighthouseServer {
     #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
     #[new]
     fn new(
@@ -307,10 +361,15 @@ impl Lighthouse {
         })
     }

+    /// address returns the address of the lighthouse server.
+    ///
+    /// Returns:
+    ///     str: The address of the lighthouse server.
     fn address(&self) -> PyResult<String> {
         Ok(self.lighthouse.address().to_string())
     }

+    /// shutdown shuts down the lighthouse server.
     fn shutdown(&self, py: Python<'_>) {
         py.allow_threads(move || {
             self.handle.abort();
@@ -339,7 +398,7 @@ impl From<Status> for StatusError {
 }

 #[pymodule]
-fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
+fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
     let mut log = stderrlog::new();
     log.verbosity(2)
@@ -353,9 +412,9 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     log.init()
         .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

-    m.add_class::<Manager>()?;
+    m.add_class::<ManagerServer>()?;
     m.add_class::<ManagerClient>()?;
-    m.add_class::<Lighthouse>()?;
+    m.add_class::<LighthouseServer>()?;
     m.add_class::<QuorumResult>()?;
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
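
The docstrings added above double as the Python-facing documentation for the renamed classes. As an illustration only — a minimal sketch assembled from those Args lists, where every address, port, replica ID, and timeout is a placeholder (not a value from this commit) and a TCPStore is assumed to already be listening at store_addr — constructing the servers from Python might look like this:

    from datetime import timedelta

    from torchft._torchft import LighthouseServer, ManagerServer

    # One lighthouse per job; it coordinates the manager of every replica group.
    lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)

    # One manager server per replica group, typically started on the rank 0 host.
    manager = ManagerServer(
        replica_id="replica_0",            # placeholder replica group ID
        lighthouse_addr=lighthouse.address(),
        hostname="localhost",
        bind="[::]:0",
        store_addr="localhost:29500",      # placeholder TCPStore address
        world_size=1,
        heartbeat_interval=timedelta(milliseconds=100),
        connect_timeout=timedelta(seconds=10),
    )

    print(manager.address())  # the address ranks pass to ManagerClient

    manager.shutdown()
    lighthouse.shutdown()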

torchft/torchft.pyi → torchft/_torchft.pyi (+4 −4)

@@ -3,15 +3,15 @@ from typing import List, Optional

 class ManagerClient:
     def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
-    def quorum(
+    def _quorum(
         self,
         rank: int,
         step: int,
         checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
     ) -> QuorumResult: ...
-    def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
+    def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
         self,
         rank: int,
@@ -33,7 +33,7 @@ class QuorumResult:
     max_world_size: int
     heal: bool

-class Manager:
+class ManagerServer:
     def __init__(
         self,
         replica_id: str,
@@ -48,7 +48,7 @@ class Manager:
     def address(self) -> str: ...
     def shutdown(self) -> None: ...

-class Lighthouse:
+class LighthouseServer:
     def __init__(
         self,
         bind: str,

torchft/coordination.py (+23)

@@ -0,0 +1,23 @@
+"""
+Coordination (Low Level API)
+============================
+
+.. warning::
+    As torchft is still in development, the APIs in this module are subject to change.
+
+This module exposes low level coordination APIs to allow you to build your own
+custom fault tolerance algorithms on top of torchft.
+
+If you're looking for a more complete solution, please use the other modules in
+torchft.
+
+This provides direct access to the Lighthouse and Manager servers and clients.
+"""
+
+from torchft._torchft import LighthouseServer, ManagerClient, ManagerServer
+
+__all__ = [
+    "LighthouseServer",
+    "ManagerServer",
+    "ManagerClient",
+]
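
The module above is the new public entry point for these classes. As a hedged client-side sketch based on the ``should_commit`` docstring — the manager address, rank, step, and timeouts below are placeholders, and a ManagerServer is assumed to already be running at that address:

    from datetime import timedelta

    from torchft.coordination import ManagerClient

    client = ManagerClient(
        "http://manager-host:29512",        # hypothetical ManagerServer address
        connect_timeout=timedelta(seconds=10),
    )

    # Per the should_commit docstring: blocks until every rank checks in at this
    # step and returns False if any rank passed should_commit=False.
    if client.should_commit(
        rank=0,
        step=10,
        should_commit=True,
        timeout=timedelta(seconds=30),
    ):
        pass  # safe to commit the step, e.g. persist the checkpoint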

torchft/coordination_test.py (+19)

@@ -0,0 +1,19 @@
+import inspect
+from unittest import TestCase
+
+from torchft.coordination import LighthouseServer, ManagerClient, ManagerServer
+
+
+class TestCoordination(TestCase):
+    def test_coordination_docs(self) -> None:
+        classes = [
+            ManagerClient,
+            ManagerServer,
+            LighthouseServer,
+        ]
+        for cls in classes:
+            self.assertIn("Args:", str(cls.__doc__), cls)
+            for name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
+                if name.startswith("_"):
+                    continue
+                self.assertIn("Args:", str(cls.__doc__), cls)

torchft/lighthouse_test.py (+4 −4)

@@ -4,15 +4,15 @@
 import torch.distributed as dist

 from torchft import Manager, ProcessGroupGloo
-from torchft.torchft import Lighthouse
+from torchft._torchft import LighthouseServer


 class TestLighthouse(TestCase):
     def test_join_timeout_behavior(self) -> None:
         """Test that join_timeout_ms affects joining behavior"""
         # To test, we create a lighthouse with 100ms and 400ms join timeouts
         # and measure the time taken to validate the quorum.
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=1,
             join_timeout_ms=100,
@@ -52,14 +52,14 @@ def test_join_timeout_behavior(self) -> None:
         if "manager" in locals():
             manager.shutdown()

-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=1,
             join_timeout_ms=400,
         )

     def test_heartbeat_timeout_ms_sanity(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=1,
             heartbeat_timeout_ms=100,

torchft/local_sgd_integ_test.py (+4 −4)

@@ -8,11 +8,11 @@
 import torch
 from torch import nn, optim

+from torchft._torchft import LighthouseServer
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
 from torchft.process_group import ProcessGroupGloo
-from torchft.torchft import Lighthouse

 logger: logging.Logger = logging.getLogger(__name__)

@@ -166,7 +166,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]

 class ManagerIntegTest(TestCase):
     def test_local_sgd_recovery(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -214,7 +214,7 @@ def test_local_sgd_recovery(self) -> None:
         self.assertEqual(failure_injectors[1].count, 1)

     def test_diloco_healthy(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -258,7 +258,7 @@ def test_diloco_healthy(self) -> None:
         )

     def test_diloco_recovery(self) -> None:
-        lighthouse = Lighthouse(
+        lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
         )

torchft/manager.py (+5 −5)

@@ -38,9 +38,9 @@
 import torch
 from torch.distributed import ReduceOp, TCPStore

+from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.torchft import Manager as _Manager, ManagerClient

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -180,7 +180,7 @@ def __init__(
             wait_for_workers=False,
         )
         self._pg = pg
-        self._manager: Optional[_Manager] = None
+        self._manager: Optional[ManagerServer] = None

         if rank == 0:
             if port is None:
@@ -192,7 +192,7 @@ def __init__(
             if replica_id is None:
                 replica_id = ""
             replica_id = replica_id + str(uuid.uuid4())
-            self._manager = _Manager(
+            self._manager = ManagerServer(
                 replica_id=replica_id,
                 lighthouse_addr=lighthouse_addr,
                 hostname=hostname,
@@ -429,7 +429,7 @@ def wait_quorum(self) -> None:
     def _async_quorum(
         self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
     ) -> None:
-        quorum = self._client.quorum(
+        quorum = self._client._quorum(
             rank=self._rank,
             step=self._step,
             checkpoint_metadata=self._checkpoint_transport.metadata(),
@@ -498,7 +498,7 @@ def _async_quorum(
             primary_client = ManagerClient(
                 recover_src_manager_address, connect_timeout=self._connect_timeout
             )
-            checkpoint_metadata = primary_client.checkpoint_metadata(
+            checkpoint_metadata = primary_client._checkpoint_metadata(
                 self._rank, timeout=self._timeout
             )
             recover_src_rank = quorum.recover_src_rank
