From 98487c560f175e5e0c6c6d5aeb2b4b5bdd73a1e6 Mon Sep 17 00:00:00 2001
From: Cyrille <cyrille@bayesimpact.org>
Date: Thu, 3 Sep 2020 21:56:25 +0200
Subject: [PATCH] Add batch creation API.

---
 .travis.yml               |  1 +
 airtable/airtable.py      |  5 ++++-
 airtable/airtable.pyi     |  8 +++++++-
 airtable/airtable_test.py | 36 ++++++++++++++++++++++++++++++++++++
 4 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 849f27d..5b7da89 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,6 +1,7 @@
 language: python
 python:
   - 2.7
+  - 3.8
 install:
   - pip install -r requirements.txt
 script: python airtable/airtable_test.py
diff --git a/airtable/airtable.py b/airtable/airtable.py
index caab139..4528171 100644
--- a/airtable/airtable.py
+++ b/airtable/airtable.py
@@ -137,7 +137,10 @@ def iterate(
 
     def create(self, table_name, data):
         if check_string(table_name):
-            payload = create_payload(data)
+            if isinstance(data, list):
+                payload = [create_payload(record) for record in data]
+            else:
+                payload = create_payload(data)
             return self.__request('POST', table_name,
                                   payload=json.dumps(payload))
 
diff --git a/airtable/airtable.pyi b/airtable/airtable.pyi
index 22d026e..4659275 100644
--- a/airtable/airtable.pyi
+++ b/airtable/airtable.pyi
@@ -61,9 +61,15 @@ class Airtable(object):
             -> _Record:
         ...
 
+    @typing.overload
     def create(self, table_name: str, data: typing.Dict[str, typing.Any]) -> _Record:
         ...
- 
+
+    @typing.overload
+    def create(self, table_name: str, data: typing.List[typing.Dict[str, typing.Any]]) \
+            -> typing.List[_Record]:
+        ...
+
     def update(self, table_name: str, record_id: str, data: typing.Dict[str, typing.Any]) -> _Record:
         ...
 
diff --git a/airtable/airtable_test.py b/airtable/airtable_test.py
index 41ba115..710b88b 100644
--- a/airtable/airtable_test.py
+++ b/airtable/airtable_test.py
@@ -1,5 +1,6 @@
 
 import airtable
+import json
 import mock
 import requests
 import unittest
@@ -152,5 +153,40 @@ def test_invalid_delete(self):
         with self.assertRaises(airtable.IsNotString):
             self.airtable.delete(FAKE_TABLE_NAME, 123)
 
+    @mock.patch.object(requests, 'request')
+    def test_create(self, mock_request):
+        record = {
+            'field1': 'value1',
+            'field2': 'value2',
+        }
+        self.airtable.create(FAKE_TABLE_NAME, record)
+        mock_request.assert_called_once()
+        unused_args, kwargs = mock_request.call_args
+        sent_data = json.loads(kwargs['data'])
+        self.assertEqual(
+            {'fields': {'field1': 'value1', 'field2': 'value2'}},
+            sent_data)
+
+    @mock.patch.object(requests, 'request')
+    def test_batch_create(self, mock_request):
+        records = [
+            {
+                'field1': 'value1',
+                'field2': 'value2',
+            },
+            {
+                'field1': 'value3',
+                'field2': 'value4',
+            },
+        ]
+        self.airtable.create(FAKE_TABLE_NAME, records)
+        mock_request.assert_called_once()
+        unused_args, kwargs = mock_request.call_args
+        sent_data = json.loads(kwargs['data'])
+        self.assertEqual([
+            {'fields': {'field1': 'value1', 'field2': 'value2'}},
+            {'fields': {'field1': 'value3', 'field2': 'value4'}},
+        ], sent_data)
+
 if __name__ == '__main__':
     unittest.main()