Skip to content

Commit

Permalink
Async http client for pull/push (#95)
Browse files Browse the repository at this point in the history
* Use async http client for pull/push

* Don't update progress too frequently

* Removed unused variable

* Rebased after added tests
  • Loading branch information
fkorotkov authored May 20, 2022
1 parent fec8032 commit 35904dc
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 75 deletions.
63 changes: 63 additions & 0 deletions Package.resolved
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"pins" : [
{
"identity" : "async-http-client",
"kind" : "remoteSourceControl",
"location" : "https://github.com/swift-server/async-http-client",
"state" : {
"revision" : "24425989dadab6d6e4167174791a23d4e2a6d0c3",
"version" : "1.10.0"
}
},
{
"identity" : "dynamic",
"kind" : "remoteSourceControl",
Expand Down Expand Up @@ -27,6 +36,60 @@
"version" : "0.8.1"
}
},
{
"identity" : "swift-log",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-log.git",
"state" : {
"revision" : "5d66f7ba25daf4f94100e7022febf3c75e37a6c7",
"version" : "1.4.2"
}
},
{
"identity" : "swift-nio",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio.git",
"state" : {
"revision" : "124119f0bb12384cef35aa041d7c3a686108722d",
"version" : "2.40.0"
}
},
{
"identity" : "swift-nio-extras",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-extras.git",
"state" : {
"revision" : "8eea84ec6144167354387ef9244b0939f5852dc8",
"version" : "1.11.0"
}
},
{
"identity" : "swift-nio-http2",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-http2.git",
"state" : {
"revision" : "72bcaf607b40d7c51044f65b0f5ed8581a911832",
"version" : "1.21.0"
}
},
{
"identity" : "swift-nio-ssl",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-ssl.git",
"state" : {
"revision" : "1750873bce84b4129b5303655cce2c3d35b9ed3a",
"version" : "2.19.0"
}
},
{
"identity" : "swift-nio-transport-services",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-transport-services.git",
"state" : {
"revision" : "1a4692acb88156e3da1b0c6732a8a38b2a744166",
"version" : "1.12.0"
}
},
{
"identity" : "swift-parsing",
"kind" : "remoteSourceControl",
Expand Down
6 changes: 4 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// swift-tools-version:5.6

import PackageDescription

let package = Package(
name: "Tart",
platforms: [
Expand All @@ -12,15 +11,18 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.1.2"),
.package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.9.2"),
.package(url: "https://github.com/mhdhejazi/Dynamic", branch: "master"),
.package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.9.2"),
.package(url: "https://github.com/swift-server/async-http-client", from: "1.10.0"),
],
targets: [
.executableTarget(name: "tart", dependencies: [
.product(name: "ArgumentParser", package: "swift-argument-parser"),
.product(name: "AsyncHTTPClient", package: "async-http-client"),
.product(name: "Dynamic", package: "Dynamic"),
.product(name: "Parsing", package: "swift-parsing"),
]),
.testTarget(name: "TartTests", dependencies: ["tart"])
]
)

7 changes: 6 additions & 1 deletion Sources/tart/Logging/ProgressObserver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Foundation
public class ProgressObserver: NSObject {
@objc var progressToObserve: Progress
var observation: NSKeyValueObservation?
var lastTimeUpdated = Date.now

public init(_ progress: Progress) {
progressToObserve = progress
Expand All @@ -11,7 +12,11 @@ public class ProgressObserver: NSObject {
func log(_ renderer: Logger) {
renderer.appendNewLine(ProgressObserver.lineToRender(progressToObserve))
observation = observe(\.progressToObserve.fractionCompleted) { progress, _ in
renderer.updateLastLine(ProgressObserver.lineToRender(self.progressToObserve))
let currentTime = Date.now
if self.progressToObserve.isFinished || currentTime.timeIntervalSince(self.lastTimeUpdated) >= 1.0 {
self.lastTimeUpdated = currentTime
renderer.updateLastLine(ProgressObserver.lineToRender(self.progressToObserve))
}
}
}

Expand Down
141 changes: 85 additions & 56 deletions Sources/tart/OCI/Registry.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
import Foundation
import NIOCore
import NIOHTTP1
import AsyncHTTPClient

enum RegistryError: Error {
case UnexpectedHTTPStatusCode(when: String, code: Int, details: String = "")
case UnexpectedHTTPStatusCode(when: String, code: UInt, details: String = "")
case MissingLocationHeader
case AuthFailed(why: String, details: String = "")
case MalformedHeader(why: String)
}

extension HTTPClientResponse.Body {
func readTextResponse() async throws -> String? {
let data = try await readResponse()
return String(decoding: data, as: UTF8.self)
}

func readResponse() async throws -> Data {
var result = Data()
for try await part in self {
result.append(Data(buffer: part))
}
return result
}
}

struct TokenResponse: Decodable {
let defaultIssuedAt = Date()
let defaultExpiresIn = 60
Expand Down Expand Up @@ -46,14 +64,16 @@ struct TokenResponse: Decodable {
(issuedAt ?? defaultIssuedAt) + TimeInterval(expiresIn ?? defaultExpiresIn)
}
}

var isValid: Bool {
get {
Date() < tokenExpiresAt
}
}
}

fileprivate let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)

class Registry {
var baseURL: URL
var namespace: String
Expand All @@ -76,41 +96,43 @@ class Registry {
}

func ping() async throws {
let (_, response) = try await endpointRequest("GET", "/v2/")
if response.statusCode != 200 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "doing ping", code: response.statusCode)
let response = try await endpointRequest(.GET, "/v2/")
if response.status != .ok {
throw RegistryError.UnexpectedHTTPStatusCode(when: "doing ping", code: response.status.code)
}
}

func pushManifest(reference: String, manifest: OCIManifest) async throws -> String {
let manifestJSON = try JSONEncoder().encode(manifest)

let (responseData, response) = try await endpointRequest("PUT", "\(namespace)/manifests/\(reference)",
let response = try await endpointRequest(.PUT, "\(namespace)/manifests/\(reference)",
headers: ["Content-Type": manifest.mediaType],
body: manifestJSON)
if response.statusCode != 201 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing manifest", code: response.statusCode,
details: String(decoding: responseData, as: UTF8.self))
if response.status != .created {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing manifest", code: response.status.code,
details: try await response.body.readTextResponse() ?? "")
}

return Digest.hash(manifestJSON)
}

public func pullManifest(reference: String) async throws -> (OCIManifest, Data) {
let (responseData, response) = try await endpointRequest("GET", "\(namespace)/manifests/\(reference)",
let response = try await endpointRequest(.GET, "\(namespace)/manifests/\(reference)",
headers: ["Accept": ociManifestMediaType])
if response.statusCode != 200 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling manifest", code: response.statusCode,
details: String(decoding: responseData, as: UTF8.self))
if response.status != .ok {
let body = try await response.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling manifest", code: response.status.code,
details: body ?? "")
}

let manifest = try JSONDecoder().decode(OCIManifest.self, from: responseData)
let manifestData = try await response.body.readResponse()
let manifest = try JSONDecoder().decode(OCIManifest.self, from: manifestData)

return (manifest, responseData)
return (manifest, manifestData)
}

private func uploadLocationFromResponse(response: HTTPURLResponse) throws -> URLComponents {
guard let uploadLocationRaw = response.value(forHTTPHeaderField: "Location") else {
private func uploadLocationFromResponse(_ response: HTTPClientResponse) throws -> URLComponents {
guard let uploadLocationRaw = response.headers.first(name: "Location") else {
throw RegistryError.MissingLocationHeader
}

Expand All @@ -123,15 +145,16 @@ class Registry {

public func pushBlob(fromData: Data, chunkSize: Int = 5 * 1024 * 1024) async throws -> String {
// Initiate a blob upload
let (postData, postResponse) = try await endpointRequest("POST", "\(namespace)/blobs/uploads/",
let postResponse = try await endpointRequest(.POST, "\(namespace)/blobs/uploads/",
headers: ["Content-Length": "0"])
if postResponse.statusCode != 202 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (POST)", code: postResponse.statusCode,
details: String(decoding: postData, as: UTF8.self))
if postResponse.status != .accepted {
let body = try await postResponse.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (POST)", code: postResponse.status.code,
details: body ?? "")
}

// Figure out where to upload the blob
let uploadLocation = try uploadLocationFromResponse(response: postResponse)
let uploadLocation = try uploadLocationFromResponse(postResponse)

// Upload the blob
let headers = [
Expand All @@ -144,46 +167,50 @@ class Registry {
"digest": digest,
]

let (putData, putResponse) = try await rawRequest("PUT", uploadLocation, headers: headers, parameters: parameters,
let putResponse = try await rawRequest(.PUT, uploadLocation, headers: headers, parameters: parameters,
body: fromData)
if putResponse.statusCode != 201 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (PUT)", code: putResponse.statusCode,
details: String(decoding: putData, as: UTF8.self))
if putResponse.status != .created {
let body = try await postResponse.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (PUT)", code: putResponse.status.code,
details: body ?? "")
}

return digest
}

public func pullBlob(_ digest: String) async throws -> Data {
let (putData, putResponse) = try await endpointRequest("GET", "\(namespace)/blobs/\(digest)")
if putResponse.statusCode != 200 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling blob", code: putResponse.statusCode,
details: String(decoding: putData, as: UTF8.self))
public func pullBlob(_ digest: String, handler: (ByteBuffer) throws -> Void) async throws {
let response = try await endpointRequest(.GET, "\(namespace)/blobs/\(digest)")
if response.status != .ok {
let body = try await response.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling blob", code: response.status.code,
details: body ?? "")
}

return putData
for try await part in response.body {
try handler(part)
}
}

private func endpointRequest(
_ method: String,
_ method: HTTPMethod,
_ endpoint: String,
headers: Dictionary<String, String> = Dictionary(),
parameters: Dictionary<String, String> = Dictionary(),
body: Data? = nil
) async throws -> (Data, HTTPURLResponse) {
) async throws -> HTTPClientResponse {
let url = URL(string: endpoint, relativeTo: baseURL)!
let urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: true)!

return try await rawRequest(method, urlComponents, headers: headers, parameters: parameters, body: body)
}

private func rawRequest(
_ method: String,
_ method: HTTPMethod,
_ urlComponents: URLComponents,
headers: Dictionary<String, String> = Dictionary(),
parameters: Dictionary<String, String> = Dictionary(),
body: Data? = nil
) async throws -> (Data, HTTPURLResponse) {
) async throws -> HTTPClientResponse {
var urlComponents = urlComponents

if urlComponents.queryItems == nil {
Expand All @@ -193,31 +220,33 @@ class Registry {
URLQueryItem(name: key, value: value)
})

var request = URLRequest(url: urlComponents.url!)
request.httpMethod = method
var request = HTTPClientRequest(url: urlComponents.string!)
request.method = method
for (key, value) in headers {
request.addValue(value, forHTTPHeaderField: key)
request.headers.add(name: key, value: value)
}
if body != nil {
request.body = HTTPClientRequest.Body.bytes(body!)
}
request.httpBody = body

// Invalidate token if it has expired
if currentAuthToken?.isValid == false {
currentAuthToken = nil
}

var (data, response) = try await authAwareRequest(request: request)
var response = try await authAwareRequest(request: request)

if response.statusCode == 401 {
if response.status == .unauthorized {
try await auth(response: response)
(data, response) = try await authAwareRequest(request: request)
response = try await authAwareRequest(request: request)
}

return (data, response)
return response
}

private func auth(response: HTTPURLResponse) async throws {
private func auth(response: HTTPClientResponse) async throws {
// Process WWW-Authenticate header
guard let wwwAuthenticateRaw = response.value(forHTTPHeaderField: "WWW-Authenticate") else {
guard let wwwAuthenticateRaw = response.headers.first(name: "WWW-Authenticate") else {
throw RegistryError.AuthFailed(why: "got HTTP 401, but WWW-Authenticate header is missing")
}

Expand Down Expand Up @@ -257,24 +286,24 @@ class Registry {
headers["Authorization"] = "Basic \(encodedCredentials!)"
}

let (tokenResponseRaw, response) = try await rawRequest("GET", authenticateURL, headers: headers)
if response.statusCode != 200 {
throw RegistryError.AuthFailed(why: "received unexpected HTTP status code \(response.statusCode) "
+ "while retrieving an authentication token", details: String(decoding: tokenResponseRaw, as: UTF8.self))
let response = try await rawRequest(.GET, authenticateURL, headers: headers)
if response.status != .ok {
let body = try await response.body.readTextResponse() ?? ""
throw RegistryError.AuthFailed(why: "received unexpected HTTP status code \(response.status.code) "
+ "while retrieving an authentication token", details: body)
}

currentAuthToken = try TokenResponse.parse(fromData: tokenResponseRaw)
let bodyData = try await response.body.readResponse()
currentAuthToken = try TokenResponse.parse(fromData: bodyData)
}

private func authAwareRequest(request: URLRequest) async throws -> (Data, HTTPURLResponse) {
private func authAwareRequest(request: HTTPClientRequest) async throws -> HTTPClientResponse {
var request = request

if let token = currentAuthToken {
request.addValue("Bearer \(token.token)", forHTTPHeaderField: "Authorization")
request.headers.add(name: "Authorization", value: "Bearer \(token.token)")
}

let (responseData, response) = try await URLSession.shared.data(for: request)

return (responseData, response as! HTTPURLResponse)
return try await httpClient.execute(request, deadline: .distantFuture)
}
}
Loading

0 comments on commit 35904dc

Please sign in to comment.