File indexing completed on 2025-08-05 08:13:33
0001
0002 import numpy as np
0003
0004 import matplotlib.pyplot as plt
0005 from matplotlib import style
0006 style.use("ggplot")
0007
0008 from sklearn import svm
0009
0010 import itertools
0011
0012
# Read the two calorimeter-response tables (electron and pion samples,
# judging by the file names). genfromtxt is called with its default
# delimiter, exactly as before.
csv_e = np.genfromtxt('testdata/set001/CalorimeterResponseHijingCentralRapidity_e_8GeV.csv')
csv_pi = np.genfromtxt('testdata/set001/CalorimeterResponseHijingCentralRapidity_pi_8GeV.csv')

# Overwrite column 1 of both tables with zeros, in place.
# NOTE(review): this column is not part of the 2:5 slice taken below, so the
# assignment has no effect on the classifier inputs - confirm it is needed.
for table in (csv_e, csv_pi):
    table[:, 1] = 0

# Working data: columns 2, 3 and 4 of each table.
data_e_sub, data_pi_sub = csv_e[:, 2:5], csv_pi[:, 2:5]

# Rows 0-499 of each sample go to training, rows 500-999 to testing.
data_e_train, data_e_test = data_e_sub[:500], data_e_sub[500:1000]
data_pi_train, data_pi_test = data_pi_sub[:500], data_pi_sub[500:1000]

# Stacking order matters for the label vectors built afterwards:
# training is electrons-then-pions, testing is pions-then-electrons.
data_train = np.vstack((data_e_train, data_pi_train))
data_test = np.vstack((data_pi_test, data_e_test))
0034
0035
# Numeric class labels: 0 marks a pion, 1 marks an electron.
id_pi = 0
id_e = 1

# Label columns are sized from the actual splits rather than a hard-coded
# 500, so they stay aligned with data_train / data_test even when an input
# file holds fewer than 1000 rows (slicing truncates silently in that case).
# Shape (n, 1) and float dtype match what np.ones / np.zeros produced before.
# Order follows the stacking of the data: electrons then pions for training,
# pions then electrons for testing.
pid_train = np.vstack((
    np.full((len(data_e_train), 1), id_e, dtype=float),
    np.full((len(data_pi_train), 1), id_pi, dtype=float),
))
pid_test = np.vstack((
    np.full((len(data_pi_test), 1), id_pi, dtype=float),
    np.full((len(data_e_test), 1), id_e, dtype=float),
))

# Training features: columns 1 and 2 of the stacked training sample.
X = data_train[:, 1:3]
# Flat 1-D label vector, the form expected by the classifiers' fit().
y = np.ravel(pid_train)
# Step size of the mesh used to draw the decision regions.
h = .001
0045
0046
# Penalty parameter shared by all four support-vector classifiers.
C = 1.0

# Configure the four models: linear-kernel SVC, RBF-kernel SVC, cubic
# polynomial-kernel SVC, and LinearSVC.
svc = svm.SVC(kernel='linear', C=C)
rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C)
poly_svc = svm.SVC(kernel='poly', degree=3, C=C)
lin_svc = svm.LinearSVC(C=C)

# Train every model on the same sample, in the original fitting order.
# fit() trains in place and returns the estimator itself, so the names above
# refer to the fitted models.
for model in (svc, rbf_svc, poly_svc, lin_svc):
    model.fit(X, y)
0052
0053
# Plot window: the range of each training feature, padded by 0.01 per side.
x_min = X[:, 0].min() - 0.01
x_max = X[:, 0].max() + 0.01
y_min = X[:, 1].min() - 0.01
y_max = X[:, 1].max() + 0.01

# Regular mesh over the window with spacing h; every mesh node is classified
# to paint the decision regions.
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

# Panel titles, listed in the same order as the classifiers plotted below.
titles = ['SVC with linear kernel',
          'LinearSVC (linear kernel)',
          'SVC with RBF kernel',
          'SVC with polynomial (degree 3) kernel']

# All mesh nodes as an (n_points, 2) matrix; identical for every classifier,
# so it is assembled once.
grid_points = np.c_[xx.ravel(), yy.ravel()]

for i, clf in enumerate((svc, lin_svc, rbf_svc, poly_svc)):
    # One panel per classifier in a 2x2 layout.
    plt.subplot(2, 2, i + 1)
    plt.subplots_adjust(wspace=0.4, hspace=0.4)

    # Predicted class at every mesh node, folded back to the mesh shape so
    # it can be drawn as filled contours.
    Z = clf.predict(grid_points).reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)

    # Overlay the training points, coloured by their true label.
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    plt.xlabel('energy EMCAL')
    plt.ylabel('energy HCAL (in)')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    # Empty tick tuples hide the axis ticks.
    plt.xticks(())
    plt.yticks(())
    plt.title(titles[i])

# Display the figure once all four panels are drawn.
plt.show()
0089
# Score every trained classifier on the test sample and print, per model,
# how many electrons and pions were identified correctly or confused.
#
# The true labels are flattened once: pid_test is an (n, 1) column while
# predict() returns a flat vector, and the two must line up element-wise.
pid_true_flat = np.ravel(pid_test)
is_true_electron = pid_true_flat == id_e
is_true_pion = pid_true_flat == id_pi

for clf in (svc, lin_svc, rbf_svc, poly_svc):

    pid_predict = np.ravel(clf.predict(data_test[:, 1:3]))

    # A length mismatch would otherwise pair predictions with the wrong
    # labels and report wrong numbers silently, so fail loudly instead.
    if pid_predict.shape != pid_true_flat.shape:
        raise ValueError(
            "number of predictions (%d) does not match number of test labels (%d)"
            % (pid_predict.size, pid_true_flat.size))

    is_pred_electron = pid_predict == id_e
    is_pred_pion = pid_predict == id_pi

    # The four outcome counts, computed with boolean masks instead of a
    # per-event Python loop.
    count_electron_as_electron = int(np.count_nonzero(is_pred_electron & is_true_electron))
    count_pion_as_pion = int(np.count_nonzero(is_pred_pion & is_true_pion))
    count_electron_as_pion = int(np.count_nonzero(is_pred_pion & is_true_electron))
    count_pion_as_electron = int(np.count_nonzero(is_pred_electron & is_true_pion))

    # Per-species totals are the sums of the outcomes for that species.
    count_true_electron = count_electron_as_electron + count_electron_as_pion
    count_true_pion = count_pion_as_pion + count_pion_as_electron

    # Single-argument print(...) calls give the same output under Python 2
    # and Python 3.
    print("-----------------------------------")
    print(clf)
    print("-----------------------------------")
    print("Electrons identified as electrons: %d out of %d" % (count_electron_as_electron, count_true_electron))
    print("Pions identified as electrons: %d out of %d" % (count_pion_as_electron, count_true_pion))
    print("-----------------------------------")
    print("Pions identified as pions: %d out of %d" % (count_pion_as_pion, count_true_pion))
    print("Electrons identified as pions: %d out of %d" % (count_electron_as_pion, count_true_electron))
    print("-----------------------------------")
0131