Skip to content

Commit 6cfe6e0

Browse files
committedMay 17, 2018
sklearn bug fix
1 parent ffecee3 commit 6cfe6e0

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed
 

‎pyson_production.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,18 @@ def getR2 (Y, predY):
451451
if SStot !=0:
452452
R2 = 1 - (SSres/SStot)
453453
#R2_v2 = SSreg/SStot
454+
#print SSreg, SSres, SStot, "in" #### test
454455
else:
455456
if SSres !=0:
456457
R2 = 0
458+
#print SSreg, SSres, SStot, "out" #### test
457459
else:
458460
R2 = 1
459461

460462
if R2<0:
461463
R2 = 0
464+
R2 = float(R2) #### test
465+
#print R2, SSres, SSreg, SStot #### test
462466
#"""
463467
return R2
464468

@@ -483,15 +487,20 @@ def best_latent_variable(X, Y, latent_variables, num_instances):
483487

484488
#3
485489
if num_instances <10:
486-
predY = cross_validation.cross_val_predict( plsca, numpy.array(X), numpy.array(Y), cv=num_instances)
490+
#predY = cross_validation.cross_val_predict( plsca, numpy.array(X), numpy.array(Y), cv=num_instances) #### was
491+
predY = cross_validation.cross_val_predict( plsca, numpy.array(X), numpy.array(Y), cv=num_instances) #### test
487492
else:
488-
predY = cross_validation.cross_val_predict( plsca, numpy.array(X), numpy.array(Y), cv=10)
489-
r2 = getR2 (Y, predY)
493+
#predY = cross_validation.cross_val_predict( plsca, numpy.array(X), numpy.array(Y), cv=10) #### was
494+
predY = cross_validation.cross_val_predict( plsca, numpy.array(X), numpy.array(Y), cv=10) #### test
495+
predY = predY.flatten()
496+
497+
r2 = getR2(Y, predY)
490498

491499
if (r2 > r2_best):
492500
r2_best = r2
493501
lv_best = lv
494-
#print r2_best
502+
503+
#print r2_best, lv_best, "\n" #### test
495504
return lv_best
496505

497506
def get_vip (fin_pls, lv_best, current_attribute, attributes_gone, attributes):
@@ -570,6 +579,7 @@ def plsvip (X, Y, V, lat_var):
570579

571580
#print Y[0], predY[0]
572581
currentR2 = getR2 (Y, predY)
582+
#print currentR2, "damn" #### test
573583
#print "R2 ", currentR2, "Avg", numpy.mean(Y), "Pred", numpy.mean(predY), "Attr", attributes, "Lat", lv_best
574584

575585
min_vip = 1000
@@ -624,7 +634,7 @@ def bestpls(vipMatrix, X, Y, V):
624634
bestR2 = vipMatrix[0][1]
625635
lv_best = vipMatrix[0][3]
626636
position = 0
627-
637+
print bestR2, lv_best, "YOLO" #### test
628638
for entries in range (len(vipMatrix)):
629639

630640
if vipMatrix[entries][1] > bestR2:
@@ -742,7 +752,7 @@ def lm_test (variables, datapoints, predictionFeature, rawModel):
742752
clf2 = pickle.loads(decoded)
743753
predictionList = []
744754
for i in range (len(datapoints)):
745-
temp = clf2.predict(datapoints[i])
755+
temp = clf2.predict([datapoints[i]]) ##[]
746756
finalPrediction = {predictionFeature:temp[0]}
747757
predictionList.append(finalPrediction)
748758
return predictionList
@@ -766,7 +776,8 @@ def gnb_test (variables, datapoints, predictionFeature, rawModel):
766776
gnb2 = pickle.loads(decoded)
767777
predictionList = []
768778
for i in range (len(datapoints)):
769-
temp = gnb2.predict(datapoints[i])
779+
## temp = gnb2.predict(datapoints[i]) #R
780+
temp = gnb2.predict([datapoints[i]]) ##[]
770781
if isinstance (temp,str):
771782
finalPrediction = {predictionFeature:temp}
772783
else:
@@ -794,7 +805,7 @@ def mnb_test (variables, datapoints, predictionFeature, rawModel):
794805
mnb2 = pickle.loads(decoded)
795806
predictionList = []
796807
for i in range (len(datapoints)):
797-
temp = mnb2.predict(datapoints[i])
808+
temp = mnb2.predict([datapoints[i]]) ##[]
798809
if isinstance (temp,str):
799810
finalPrediction = {predictionFeature:temp}
800811
else:
@@ -822,7 +833,7 @@ def bnb_test (variables, datapoints, predictionFeature, rawModel):
822833
bnb2 = pickle.loads(decoded)
823834
predictionList = []
824835
for i in range (len(datapoints)):
825-
temp = bnb2.predict(datapoints[i])
836+
temp = bnb2.predict([datapoints[i]]) ##[]
826837
if isinstance (temp,str):
827838
finalPrediction = {predictionFeature:temp}
828839
else:
@@ -850,7 +861,7 @@ def lasso_test (variables, datapoints, predictionFeature, rawModel):
850861
clf2 = pickle.loads(decoded)
851862
predictionList = []
852863
for i in range (len(datapoints)):
853-
temp = clf2.predict(datapoints[i])
864+
temp = clf2.predict([datapoints[i]]) ##[]
854865
finalPrediction = {predictionFeature:temp[0]}
855866
predictionList.append(finalPrediction)
856867
return predictionList
@@ -1794,18 +1805,28 @@ def __call__(self, environ, start_response):
17941805
from cStringIO import StringIO
17951806
input = environ.get('wsgi.input')
17961807
length = environ.get('CONTENT_LENGTH', '0')
1808+
17971809
length = 0 if length == '' else int(length)
1810+
17981811
body = ''
17991812
if length == 0:
18001813
environ['body_copy'] = ''
18011814
if input is None:
18021815
return
1816+
18031817
if environ.get('HTTP_TRANSFER_ENCODING','0') == 'chunked':
1804-
size = int(input.readline(),16)
1805-
while size > 0:
1806-
temp = str(input.read(size+2)).strip()
1807-
body += temp
1808-
size = int(input.readline(),16)
1818+
while (1):
1819+
temp = input.readline() ##
1820+
1821+
if not temp:
1822+
break
1823+
body +=temp
1824+
size = len(body)
1825+
#if environ.get('HTTP_TRANSFER_ENCODING','0') == 'chunked':
1826+
# body = input.readline() ##
1827+
# print "\n\n\n\nBODY\n\n\n\n", body
1828+
# size = len(body)
1829+
18091830
else:
18101831
body = environ['wsgi.input'].read(length)
18111832
environ['body_copy'] = body
@@ -1833,4 +1854,5 @@ def callback(status, headers, exc_info=None):
18331854
app.wsgi_app = WSGICopyBody(app.wsgi_app) ##
18341855
app.run(host="0.0.0.0", port = 5000, debug = True)
18351856

1836-
#curl -i -H "Content-Type: application/json" -X POST -d @C:/Python27/Flask-0.10.1/python-api/vipbugtrain.json http://localhost:5000/pws/vip/train
1857+
## curl -i -H "Transfer-encoding:chunked" -H "Content-Type:application/json" -X POST -d @C:/Python27-15/trainCategorical.json http://localhost:5000/pws/gnb/train
1858+
# curl -i -H "Transfer-encoding:chunked" -H "Content-Type:application/json" -X POST -d @C:/Python27-15/trainCategorical.json http://192.168.99.100:5000/pws/gnb/train

0 commit comments

Comments
 (0)
Please sign in to comment.