Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to subscribe to multiple keys and to prevent propagation #49

Merged
merged 19 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Sources/Defaults/Defaults.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// MIT License © Sindre Sorhus
import Foundation

public protocol _DefaultsBaseKey: Defaults.Keys {
public protocol DefaultsBaseKey: Defaults.Keys {
fredyshox marked this conversation as resolved.
Show resolved Hide resolved
var name: String { get }
var suite: UserDefaults { get }
}

extension _DefaultsBaseKey {
extension DefaultsBaseKey {
/// Reset the item back to its default value.
public func reset() {
suite.removeObject(forKey: name)
Expand All @@ -26,7 +26,7 @@ public enum Defaults {
fileprivate init() {}
}

public final class Key<Value: Codable>: Keys, _DefaultsBaseKey {
public final class Key<Value: Codable>: Keys, DefaultsBaseKey {
public let name: String
public let defaultValue: Value
public let suite: UserDefaults
Expand All @@ -53,7 +53,7 @@ public enum Defaults {
}

@available(iOS 11.0, macOS 10.13, tvOS 11.0, watchOS 4.0, iOSApplicationExtension 11.0, macOSApplicationExtension 10.13, tvOSApplicationExtension 11.0, watchOSApplicationExtension 4.0, *)
public final class NSSecureCodingKey<Value: NSSecureCoding>: Keys, _DefaultsBaseKey {
public final class NSSecureCodingKey<Value: NSSecureCoding>: Keys, DefaultsBaseKey {
public let name: String
public let defaultValue: Value
public let suite: UserDefaults
Expand All @@ -80,7 +80,7 @@ public enum Defaults {
}

@available(iOS 11.0, macOS 10.13, tvOS 11.0, watchOS 4.0, iOSApplicationExtension 11.0, macOSApplicationExtension 10.13, tvOSApplicationExtension 11.0, watchOSApplicationExtension 4.0, *)
public final class NSSecureCodingOptionalKey<Value: NSSecureCoding>: Keys, _DefaultsBaseKey {
public final class NSSecureCodingOptionalKey<Value: NSSecureCoding>: Keys, DefaultsBaseKey {
public let name: String
public let suite: UserDefaults

Expand Down
2 changes: 1 addition & 1 deletion Sources/Defaults/Observation+Combine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ extension Defaults {
*/
@available(iOS 13.0, macOS 10.15, tvOS 13.0, watchOS 6.0, iOSApplicationExtension 13.0, macOSApplicationExtension 10.15, tvOSApplicationExtension 13.0, watchOSApplicationExtension 6.0, *)
public static func publisher(
keys: _DefaultsBaseKey...,
keys: DefaultsBaseKey...,
options: ObservationOptions = [.initial]
) -> AnyPublisher<Void, Never> {
let initial = Empty<Void, Never>(completeImmediately: false).eraseToAnyPublisher()
Expand Down
114 changes: 113 additions & 1 deletion Sources/Defaults/Observation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,105 @@ extension Defaults {
else {
return
}

callback(BaseChange(change: change))
}
}

class CompositeUserDefaultsKeyObservation: NSObject, Observation {
private static var observationContext = 0

class SuiteKeyPair {
weak var suite: UserDefaults?
let key: String

init(suite: UserDefaults, key: String) {
self.suite = suite
self.key = key
}
}

private var observables: [SuiteKeyPair]
private var lifetimeAssociation: LifetimeAssociation? = nil
private let preventPropagation: Bool
private let callback: UserDefaultsKeyObservation.Callback

init(observables: [SuiteKeyPair], preventPropagation: Bool, callback: @escaping UserDefaultsKeyObservation.Callback) {
self.observables = observables
self.preventPropagation = preventPropagation
self.callback = callback
super.init()
}

deinit {
invalidate()
}

public func start(options: ObservationOptions) {
for observable in observables {
observable.suite?.addObserver(self,
forKeyPath: observable.key,
options: options.toNSKeyValueObservingOptions,
context: &type(of: self).observationContext)
}
}

public func invalidate() {
for observable in observables {
observable.suite?.removeObserver(self, forKeyPath: observable.key, context: &type(of: self).observationContext)
observable.suite = nil
}
lifetimeAssociation?.cancel()
}

public func tieToLifetime(of weaklyHeldObject: AnyObject) -> Self {
lifetimeAssociation = LifetimeAssociation(of: self, with: weaklyHeldObject, deinitHandler: { [weak self] in
self?.invalidate()
})

return self
}

public func removeLifetimeTie() {
lifetimeAssociation?.cancel()
}

// swiftlint:disable:next block_based_kvo
override func observeValue(
forKeyPath keyPath: String?,
of object: Any?,
change: [NSKeyValueChangeKey: Any]?, // swiftlint:disable:this discouraged_optional_collection
context: UnsafeMutableRawPointer?
) {
guard
context == &type(of: self).observationContext
else {
super.observeValue(forKeyPath: keyPath, of: object, change: change, context: context)
return
}

guard
object is UserDefaults,
let change = change
else {
return
}

if preventPropagation {
let key = "\(type(of: self))_updatingValuesFlag"
let updatingValuesFlag = (Thread.current.threadDictionary[key] as? Bool) ?? false
if updatingValuesFlag {
return
}

Thread.current.threadDictionary[key] = true
callback(BaseChange(change: change))
Thread.current.threadDictionary[key] = false
} else {
callback(BaseChange(change: change))
}
}
}

/**
Observe a defaults key.
Expand Down Expand Up @@ -268,6 +363,23 @@ extension Defaults {
observation.start(options: options)
return observation
}

public static func observe(
fredyshox marked this conversation as resolved.
Show resolved Hide resolved
keys: DefaultsBaseKey...,
sindresorhus marked this conversation as resolved.
Show resolved Hide resolved
options: ObservationOptions = [.initial],
preventPropagation: Bool = false,
handler: @escaping () -> Void
) -> Observation {
let pairs = keys.map {
CompositeUserDefaultsKeyObservation.SuiteKeyPair(suite: $0.suite, key: $0.name)
}
let compositeObservation = CompositeUserDefaultsKeyObservation(observables: pairs, preventPropagation: preventPropagation) { _ in
handler()
}
compositeObservation.start(options: options)

return compositeObservation
}
}

extension Defaults.ObservationOptions {
Expand Down
2 changes: 1 addition & 1 deletion Sources/Defaults/Reset.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ TODO: When Swift gets support for static key paths, all of this could be simplif

```
extension Defaults {
public static func reset(_ keys: KeyPath<Keys, _DefaultsBaseKey>...) {
public static func reset(_ keys: KeyPath<Keys, DefaultsBaseKey>...) {
for key in keys {
Keys[keyPath: key].reset()
}
Expand Down
94 changes: 94 additions & 0 deletions Tests/DefaultsTests/DefaultsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,53 @@ final class DefaultsTests: XCTestCase {

waitForExpectations(timeout: 10)
}

func testObserveMultipleKeys() {
let key1 = Defaults.Key<String>("observeKey1", default: "x")
let key2 = Defaults.Key<Bool>("observeKey2", default: true)
let expect = expectation(description: "Observation closure being called")

var observation: Defaults.Observation!
var counter = 0
observation = Defaults.observe(keys: key1, key2, options: []) {
counter += 1
if counter == 2 {
expect.fulfill()
} else if counter > 2 {
XCTFail()
}
}

Defaults[key1] = "y"
Defaults[key2] = false
observation.invalidate()

waitForExpectations(timeout: 10)
}

@available(iOS 11.0, macOS 10.13, tvOS 11.0, watchOS 4.0, iOSApplicationExtension 11.0, macOSApplicationExtension 10.13, tvOSApplicationExtension 11.0, watchOSApplicationExtension 4.0, *)
func testObserveMultipleNSSecureKeys() {
let key1 = Defaults.NSSecureCodingKey<ExamplePersistentHistory>("observeNSSecureCodingKey1", default: ExamplePersistentHistory(value: "TestValue"))
let key2 = Defaults.NSSecureCodingKey<ExamplePersistentHistory>("observeNSSecureCodingKey2", default: ExamplePersistentHistory(value: "TestValue"))
let expect = expectation(description: "Observation closure being called")

var observation: Defaults.Observation!
var counter = 0
observation = Defaults.observe(keys: key1, key2, options: []) {
counter += 1
if counter == 2 {
expect.fulfill()
} else if counter > 2 {
XCTFail()
}
}

Defaults[key1] = ExamplePersistentHistory(value: "NewTestValue1")
Defaults[key2] = ExamplePersistentHistory(value: "NewTestValue2")
observation.invalidate()

waitForExpectations(timeout: 10)
}

func testObserveKeyURL() {
let fixtureURL = URL(string: "https://sindresorhus.com")!
Expand Down Expand Up @@ -488,6 +535,53 @@ final class DefaultsTests: XCTestCase {

waitForExpectations(timeout: 10)
}

func testObservePreventPropagation() {
let key1 = Defaults.Key<Bool?>("preventPropagation1", default: nil)
let key2 = Defaults.Key<Bool?>("preventPropagation2", default: nil)
let expect = expectation(description: "No infinite recursion")

var observation: Defaults.Observation!
var wasInside = false
observation = Defaults.observe(keys: key1, key2, options: [], preventPropagation: true) {
XCTAssertFalse(wasInside)
wasInside = true
Defaults[key1] = true
expect.fulfill()
}

Defaults[key1] = false
observation.invalidate()

waitForExpectations(timeout: 10)
}

func testObservePreventPropagationMultiThread() {
let key1 = Defaults.Key<Int?>("preventPropagation3", default: nil)
let expect = expectation(description: "No infinite recursion")

var observation: Defaults.Observation!
// This checks if callback is still being called, if value is changed on second thread,
// while initial thread is doing some long lasting task.
observation = Defaults.observe(keys: key1, options: [], preventPropagation: true) {
Defaults[key1]! += 1
print("--- Main Thread: \(Thread.isMainThread)")
if !Thread.isMainThread {
XCTAssert(Defaults[key1]! == 4)
expect.fulfill()
} else {
usleep(100000)
print("--- Release: \(Thread.isMainThread)")
}
}
DispatchQueue.global().asyncAfter(deadline: .now() + 0.05) {
Defaults[key1]! += 1
}
Defaults[key1] = 1
observation.invalidate()

waitForExpectations(timeout: 10)
}

func testResetKey() {
let defaultFixture1 = "foo1"
Expand Down