
57. Classification based on a probabilistic neural network (PNN) (MATLAB)

2024-07-12


1. Introduction to classification based on a probabilistic neural network (PNN)

PNN (Probabilistic Neural Network) is a neural network model grounded in probability theory and used mainly for classification. It was proposed by Donald F. Specht in the late 1980s (the widely cited journal paper appeared in 1990) and is a simple classifier that often performs well in practice.

The principle of PNN can be summarized in the following steps:

  1. Input layer: each sample vector is fed into the network.
  2. Pattern layer: the input is compared with every stored training sample, typically through a Gaussian kernel of their distance, and the similarity scores are accumulated for each class.
  3. Output (competitive) layer: the class scores are compared, and the class with the highest score is the classification result.
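
The three steps above can be sketched in base MATLAB without any toolbox. This is a minimal illustration, not the newpnn implementation: it assumes a Gaussian kernel with a hypothetical width sigma = 1 and averages the kernel values of each class's samples. It reuses the three training vectors from this post and needs implicit expansion (MATLAB R2016b or later).

```matlab
% Minimal PNN sketch in base MATLAB (no toolbox). Assumptions: Gaussian kernel,
% hypothetical width sigma = 1, per-class averaging of kernel values.
X  = [1 2; 2 2; 1 1]';    % training samples, one per column
Tc = [1 2 3];             % class index of each column
sigma = 1;                % hypothetical kernel width

x = [2; 1.5];             % step 1 (input layer): the vector to classify

% Step 2 (pattern layer): similarity of x to each stored sample,
% averaged over the samples of each class.
numClasses = max(Tc);
score = zeros(1, numClasses);
for k = 1:numClasses
    members = X(:, Tc == k);                  % training samples of class k
    d2 = sum((members - x).^2, 1);            % squared Euclidean distances
    score(k) = mean(exp(-d2 / (2*sigma^2)));  % average Gaussian kernel value
end

% Step 3 (output layer): the class with the highest score wins.
[~, predictedClass] = max(score);
fprintf('predicted class: %d\n', predictedClass);
```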

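PNN has the following characteristics:

  1. Efficiency: training amounts to storing the training samples (there is no iterative weight optimization), so it is very fast. The trade-off is that classifying a new sample requires comparing it with every stored sample.
  2. Robustness: because each class score combines kernel contributions from many samples, a few noisy samples usually have limited influence on the decision, although a very small spread makes the network more sensitive to individual samples.
  3. Interpretability: the class scores can be read as estimates of class-conditional probability densities, so users can see why the network chose a class.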
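
To illustrate the interpretability point, the class scores can be normalized into estimated posterior probabilities. The sketch below is toolbox-free and assumes equal class priors and a Gaussian kernel of hypothetical width sigma = 1; accumarray averages the kernel values per class.

```matlab
% Turning PNN class scores into estimated posterior probabilities.
% Assumptions: Gaussian kernel, hypothetical width sigma = 1, equal class priors.
X  = [1 2; 2 2; 1 1]';    % training samples (columns)
Tc = [1 2 3];             % their classes
sigma = 1;
x = [2; 1.5];             % vector to classify

kval = exp(-sum((X - x).^2, 1) / (2*sigma^2));          % kernel value per sample
classScore = accumarray(Tc(:), kval(:), [], @mean)';     % average per class (row)
posterior = classScore / sum(classScore);                % normalize to sum to 1
disp(posterior)
```

Because the normalized scores sum to 1, they can be reported alongside the predicted class as a rough confidence measure.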

Overall, PNN is a simple and often effective classifier that has been used in many fields, such as image recognition and text classification.

2. Problem description and key functions for PNN classification

1) Description

Here there are three two-element input vectors X and their associated classes Tc.
The goal is to create a probabilistic neural network that classifies these vectors correctly.

2) Important functions

newpnn() function: designs a probabilistic neural network

A probabilistic neural network (PNN) is a radial basis network suited to classification problems.

Syntax

    net = newpnn(P,T,spread)

newpnn accepts two or three arguments (spread is optional) and returns a new probabilistic neural network.

Parameters

P: R × Q matrix of Q input vectors

T: S × Q matrix of Q target class vectors

spread: spread of the radial basis functions (default = 0.1)

If spread is close to zero, the network behaves as a nearest-neighbor classifier. As spread grows, the network takes several nearby design vectors into account.
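
The effect of spread can be reproduced without the toolbox. According to the newpnn documentation, the first-layer biases are set to 0.8326/spread, so each radial basis neuron outputs about 0.5 for an input exactly spread away from its training vector. The second layer then sums these outputs per class, and a competitive transfer function picks the maximum. The sketch below mimics those two layers on a hypothetical one-dimensional data set in which class 1 has a single sample near the query and class 2 has three samples slightly farther away.

```matlab
% Toolbox-free sketch of the two layers newpnn builds (radial basis + competitive).
% Hypothetical 1-D data: class 1 has one sample, class 2 has three samples.
P  = [0 1.2 1.4 1.6];                   % 1-by-Q training inputs
Tc = [1 2 2 2];                         % class index of each training input
T  = full(sparse(Tc, 1:numel(Tc), 1));  % S-by-Q one-hot targets
x  = 0.5;                               % query: its nearest sample belongs to class 1

radbas  = @(n) exp(-n.^2);              % radial basis transfer function
spreads = [0.1 2];
cls = zeros(size(spreads));
for i = 1:numel(spreads)
    b  = 0.8326 / spreads(i);           % first-layer bias as described for newpnn
    d  = sqrt(sum((P - x).^2, 1));      % distance from x to every training input
    a1 = radbas(d * b);                 % layer 1: one output per training input
    n2 = T * a1';                       % layer 2: per-class sum of layer-1 outputs
    [~, cls(i)] = max(n2);              % competitive output: index of the winner
    fprintf('spread = %.1f -> class %d\n', spreads(i), cls(i));
end

% A training input exactly "spread" away gives a layer-1 output of about 0.5,
% because (spread * 0.8326/spread) = 0.8326 for any spread.
halfValue = radbas(0.8326);
```

With the small spread only the closest sample matters, so the result matches a nearest-neighbor rule (class 1). With the large spread the three class-2 samples all contribute substantially and outvote the single closer sample (class 2).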

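sim() function: simulates a neural network

Syntax

    [Y,Xf,Af] = sim(net,X,Xi,Ai,T)

Parameters

net: network

X: network inputs

Xi: initial input delay conditions (default = 0)

Ai: initial layer delay conditions (default = 0)

T: network targets (default = 0)

Outputs: Y (network outputs), Xf (final input delay conditions) and Af (final layer delay conditions). Calling the network object directly, as in Y = net(X), is equivalent to Y = sim(net,X); the code in this post uses that shorter form.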

3. Dataset and display

Code

    X = [1 2; 2 2; 1 1]';   % three two-element input vectors, one per column
    Tc = [1 2 3];           % class index of each vector
    figure(1)
    plot(X(1,:),X(2,:),'.','markersize',30)
    for i = 1:3, text(X(1,i)+0.1,X(2,i),sprintf('class %g',Tc(i))), end
    axis([0 3 0 3])
    title('Three two-element vectors and their classes')
    xlabel('X(1,:)')
    ylabel('X(2,:)')

Result

[Figure: the three input vectors, labeled with their classes]

4. Test the network on the design input vectors

1) Description

Convert the target class indices Tc to vectors T with ind2vec (each class index becomes a one-hot column).
Design the probabilistic neural network with newpnn.
SPREAD is set to 1 because that is a typical distance between the input vectors.
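
For readers without the toolbox, the index/vector conversion can be reproduced in base MATLAB. This is an equivalent sketch rather than the toolbox's own code; the class list Tc = [2 1 3 2] is hypothetical and chosen so that the result is not simply the identity matrix.

```matlab
% Base-MATLAB equivalents of ind2vec / vec2ind (sketch, not the toolbox code).
Tc = [2 1 3 2];                      % hypothetical class indices

% One-hot encoding: column j has a 1 in row Tc(j); stored as a sparse matrix.
T = sparse(Tc, 1:numel(Tc), 1);
disp(full(T))

% Decoding one-hot columns: the row index of the 1 in each column.
[~, Yc] = max(full(T), [], 1);
disp(Yc)
```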

2) Test network

Code

    T = ind2vec(Tc);            % class indices -> one-hot (sparse) target vectors
    spread = 1;
    net = newpnn(X,T,spread);
    % Test the network
    % Simulate the network on the input vectors and convert its vector outputs to class indices.
    Y = net(X);
    Yc = vec2ind(Y);
    figure(2)
    plot(X(1,:),X(2,:),'.','markersize',30)
    axis([0 3 0 3])
    for i = 1:3, text(X(1,i)+0.1,X(2,i),sprintf('class %g',Yc(i))), end
    title('Testing the network')
    xlabel('X(1,:)')
    ylabel('X(2,:)')

Result

[Figure: the class predicted by the network for each input vector]

3) Test the network with new data

Code

    x = [2; 1.5];
    y = net(x);
    ac = vec2ind(y);
    figure(3)
    plot(X(1,:),X(2,:),'.','markersize',30)              % training vectors
    hold on                                              % keep them while adding x
    plot(x(1),x(2),'.','markersize',30,'color',[1 0 0])  % new vector in red
    text(x(1)+0.1,x(2),sprintf('class %g',ac))
    hold off
    axis([0 3 0 3])
    title('Classification of new data')
    xlabel('X(1,:) and x(1)')
    ylabel('X(2,:) and x(2)')

Result

[Figure: classification of the new vector x = [2; 1.5], shown in red]
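
Why does x = [2; 1.5] end up in class 2? With one training vector per class and a kernel that decreases with distance, the class of the nearest training vector always receives the largest score. The short base-MATLAB calculation below checks the distances.

```matlab
X = [1 2; 2 2; 1 1]';            % training vectors (columns), classes 1, 2, 3
x = [2; 1.5];                    % new vector
d = sqrt(sum((X - x).^2, 1));    % Euclidean distance to each training vector
disp(d)                          % approximately 1.1180  0.5000  1.1180
[~, nearest] = min(d);           % index of the closest training vector
```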

5. The probabilistic neural network divides the input space into three classes.

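Description

Evaluate the network on a grid of points covering the input space and draw each class's output as a colored surface. Viewed from above, the colors show the region of the input space that the network assigns to each of the three classes.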

Code

    x1 = 0:.05:3;
    x2 = x1;
    [X1,X2] = meshgrid(x1,x2);
    xx = [X1(:) X2(:)]';            % grid points as columns
    yy = net(xx);
    yy = full(yy);                  % convert (possibly sparse) outputs to a full matrix
    figure(4)                       % new figure, so figure 3 is not overwritten
    m = mesh(X1,X2,reshape(yy(1,:),size(X1)));
    m.FaceColor = [0 0.5 1];
    m.LineStyle = 'none';
    hold on
    m = mesh(X1,X2,reshape(yy(2,:),size(X1)));
    m.FaceColor = [0 1.0 0.5];
    m.LineStyle = 'none';
    m = mesh(X1,X2,reshape(yy(3,:),size(X1)));
    m.FaceColor = [0.5 0 1];
    m.LineStyle = 'none';
    plot3(X(1,:),X(2,:),[1 1 1]+0.1,'.','markersize',30)
    plot3(x(1),x(2),1.1,'.','markersize',30,'color',[1 0 0])
    hold off
    view(2)
    title('Three-class partition')
    xlabel('X(1,:) and x(1)')
    ylabel('X(2,:) and x(2)')

Result

[Figure: three-class partition of the input space, viewed from above]
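
The same decision regions can also be computed without the toolbox, as a label map instead of three surfaces. The sketch assumes newpnn-style layers (bias 0.8326/spread, per-class sums, maximum wins) with spread = 1, and computes all grid-to-training distances at once with implicit expansion (MATLAB R2016b or later).

```matlab
% Toolbox-free decision regions for the three training vectors of this post.
% Assumptions: newpnn-style layers (bias 0.8326/spread, per-class sums, max wins).
X  = [1 2; 2 2; 1 1]';               % training vectors (columns)
Tc = [1 2 3];                        % their classes
spread = 1;
b = 0.8326 / spread;                 % first-layer bias

x1 = 0:.05:3;
[X1, X2] = meshgrid(x1, x1);
G = [X1(:) X2(:)]';                  % grid points as columns (2-by-N)

% Squared distances from every grid point to every training vector (N-by-Q),
% via |g|^2 + |p|^2 - 2 g'p with implicit expansion; clamp rounding below zero.
D2 = max(sum(G.^2, 1)' + sum(X.^2, 1) - 2 * (G' * X), 0);
A1 = exp(-D2 * b^2);                 % radbas(dist*b) = exp(-(dist*b)^2)
T  = full(sparse(Tc, 1:numel(Tc), 1));
N2 = A1 * T';                        % per-class sums (N-by-S)
[~, label] = max(N2, [], 2);         % winning class for each grid point
labelMap = reshape(label, size(X1)); % rows follow x2, columns follow x1
```

To display the result, imagesc(x1, x1, labelMap) followed by axis xy draws one color per class. With a single training vector per class, the regions are simply the nearest-neighbor (Voronoi) cells of the three points.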

6. Summary

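A probabilistic neural network (PNN) is an artificial neural network for pattern classification. It combines the Bayes decision rule with Parzen-window (kernel) estimates of the class-conditional probability densities, which can also be viewed as a Gaussian mixture with one component per training sample. It handles continuous features directly and discrete features once they are encoded numerically. Compared with networks trained by iterative optimization, PNN needs no iterative training and can reach competitive accuracy on many problems, but it must store all training samples and its performance depends on the choice of spread.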
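
In formulas, with training samples x_{k,i} (i = 1, ..., N_k) for class k, a Gaussian kernel of width σ and class priors π_k, the Parzen estimate and the Bayes decision used by a PNN can be written as follows. The normalizing constant (2π)^{d/2}σ^d is omitted because it is the same for every class.

```latex
\hat f_k(\mathbf{x}) = \frac{1}{N_k}\sum_{i=1}^{N_k}
  \exp\!\left(-\frac{\lVert \mathbf{x}-\mathbf{x}_{k,i}\rVert^{2}}{2\sigma^{2}}\right),
\qquad
\hat k(\mathbf{x}) = \arg\max_{k}\; \pi_k \,\hat f_k(\mathbf{x})
```

With equal priors this picks the class with the largest averaged kernel value. Summing instead of averaging, as the second layer of newpnn does, corresponds to priors proportional to N_k. Matching exp(-(0.8326 d/spread)^2) with exp(-d^2/(2σ^2)) gives σ = spread/sqrt(2 ln 2), roughly 0.85 · spread.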

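The basic working principle of PNN is to measure the similarity between an input vector and each training sample and to classify the input according to these similarities. A PNN consists of four layers: input layer, pattern layer, summation layer and output (competitive) layer. The input layer passes the input to the pattern layer, where each neuron computes the similarity (a Gaussian kernel of the distance) between the input and one training sample; the summation layer adds these similarities class by class; and the output layer assigns the input to the class with the largest sum. MATLAB's newpnn packs this into two layers: a radial basis layer, followed by a layer that sums per class and applies a competitive transfer function.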
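
The four layers can be written out one after another in base MATLAB. In this sketch the training set (two samples per class), the query points and the kernel width sigma = 0.5 are all hypothetical. The summation layer is a matrix product with a class-indicator matrix, which adds up the pattern outputs of each class.

```matlab
% Layer-by-layer PNN sketch in base MATLAB (no toolbox).
% Hypothetical data: two training samples per class, three classes.
Ptrain = [1 2; 1.2 2.1; 2 2; 2.1 1.8; 1 1; 0.9 1.2]';   % 2-by-6, one sample per column
Ttrain = [1 1 2 2 3 3];                                  % class of each sample
sigma  = 0.5;                                            % hypothetical kernel width
Q = [1.1 2.0; 2.0 1.9; 1.0 1.0]';                        % 2-by-3 query vectors

% Input layer: the query vectors are passed on unchanged.
% Pattern layer: one Gaussian unit per training sample (queries-by-samples).
D2 = max(sum(Q.^2, 1)' + sum(Ptrain.^2, 1) - 2 * (Q' * Ptrain), 0);
pattern = exp(-D2 / (2 * sigma^2));

% Summation layer: add the pattern outputs that belong to each class.
S = full(sparse(Ttrain, 1:numel(Ttrain), 1));   % classes-by-samples indicator
summation = pattern * S';                        % queries-by-classes

% Output layer: choose the class with the largest summed activation.
[~, predicted] = max(summation, [], 2);
disp(predicted')
```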

In MATLAB you can implement PNN classification with newpnn from the Deep Learning Toolbox (formerly Neural Network Toolbox), or program it yourself. First prepare a training set and a test set, then design the PNN from the training set; for a PNN this step simply stores the training samples and sets spread. Afterwards, use the test set to evaluate the classification performance of the network before using it for prediction.
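
A hold-out evaluation might look like the sketch below. It does not use the toolbox: the training and test sets, the kernel width sigma = 0.5 and the anonymous function scores are hypothetical choices for illustration. The last test sample lies between classes 1 and 2 and is assigned to class 2, which shows up as an off-diagonal entry in the confusion matrix.

```matlab
% Hold-out evaluation of a toolbox-free PNN sketch.
% Hypothetical data; sigma = 0.5 is an assumed kernel width.
Ptrain = [1 2; 1.2 2.1; 2 2; 2.1 1.8; 1 1; 0.9 1.2]';   % training samples (columns)
Ttrain = [1 1 2 2 3 3];
Ptest  = [1.1 1.9; 1.9 2.1; 1.1 0.9; 1.6 2.0]';         % test samples (columns)
Ttest  = [1 2 3 1];                                      % true test labels
sigma  = 0.5;

% "Training" a PNN only stores Ptrain/Ttrain; the class indicator S groups samples.
S = full(sparse(Ttrain, 1:numel(Ttrain), 1));
% scores(Qm): per-class sums of Gaussian kernel values for query columns Qm.
scores = @(Qm) exp(-max(sum(Qm.^2, 1)' + sum(Ptrain.^2, 1) ...
                        - 2 * (Qm' * Ptrain), 0) / (2 * sigma^2)) * S';

[~, predicted] = max(scores(Ptest), [], 2);
predicted = predicted';                                   % row vector, like Ttest

accuracy  = mean(predicted == Ttest);
confusion = accumarray([Ttest(:) predicted(:)], 1, [3 3]);  % rows: true, cols: predicted
fprintf('accuracy = %.2f\n', accuracy);
disp(confusion)
```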

In general, PNN is a practical classification method for many problems. In practice, performance can be improved by selecting suitable features and tuning the model parameters (chiefly spread) for the specific problem. MATLAB's toolbox functions make PNN straightforward to implement and apply.

7. Source code

Code

    %% Classification based on a probabilistic neural network (PNN) (MATLAB)
    % There are three two-element input vectors X and their associated classes Tc.
    % Create a probabilistic neural network that classifies these vectors correctly.
    % Key functions: NEWPNN and SIM
    %% Dataset and display
    X = [1 2; 2 2; 1 1]';
    Tc = [1 2 3];
    figure(1)
    plot(X(1,:),X(2,:),'.','markersize',30)
    for i = 1:3, text(X(1,i)+0.1,X(2,i),sprintf('class %g',Tc(i))), end
    axis([0 3 0 3])
    title('Three two-element vectors and their classes')
    xlabel('X(1,:)')
    ylabel('X(2,:)')
    %% Test the network on the design input vectors
    % Convert the target class indices Tc to vectors T.
    % Design a probabilistic neural network with NEWPNN.
    % SPREAD is set to 1 because that is a typical distance between the input vectors.
    T = ind2vec(Tc);
    spread = 1;
    net = newpnn(X,T,spread);
    % Test the network
    % Simulate the network on the input vectors and convert its vector outputs to class indices.
    Y = net(X);
    Yc = vec2ind(Y);
    figure(2)
    plot(X(1,:),X(2,:),'.','markersize',30)
    axis([0 3 0 3])
    for i = 1:3, text(X(1,i)+0.1,X(2,i),sprintf('class %g',Yc(i))), end
    title('Testing the network')
    xlabel('X(1,:)')
    ylabel('X(2,:)')
    %% Classify new data
    x = [2; 1.5];
    y = net(x);
    ac = vec2ind(y);
    figure(3)
    plot(X(1,:),X(2,:),'.','markersize',30)              % training vectors
    hold on                                              % keep them while adding x
    plot(x(1),x(2),'.','markersize',30,'color',[1 0 0])  % new vector in red
    text(x(1)+0.1,x(2),sprintf('class %g',ac))
    hold off
    axis([0 3 0 3])
    title('Classification of new data')
    xlabel('X(1,:) and x(1)')
    ylabel('X(2,:) and x(2)')
    %% The probabilistic neural network divides the input space into three classes
    x1 = 0:.05:3;
    x2 = x1;
    [X1,X2] = meshgrid(x1,x2);
    xx = [X1(:) X2(:)]';
    yy = net(xx);
    yy = full(yy);
    figure(4)
    m = mesh(X1,X2,reshape(yy(1,:),size(X1)));
    m.FaceColor = [0 0.5 1];
    m.LineStyle = 'none';
    hold on
    m = mesh(X1,X2,reshape(yy(2,:),size(X1)));
    m.FaceColor = [0 1.0 0.5];
    m.LineStyle = 'none';
    m = mesh(X1,X2,reshape(yy(3,:),size(X1)));
    m.FaceColor = [0.5 0 1];
    m.LineStyle = 'none';
    plot3(X(1,:),X(2,:),[1 1 1]+0.1,'.','markersize',30)
    plot3(x(1),x(2),1.1,'.','markersize',30,'color',[1 0 0])
    hold off
    view(2)
    title('Three-class partition')
    xlabel('X(1,:) and x(1)')
    ylabel('X(2,:) and x(2)')