Copied remotely
[mono.git] / mcs / class / ByteFX.Data / mysqlclient / Driver.cs
1 // ByteFX.Data data access components for .Net\r
2 // Copyright (C) 2002-2003  ByteFX, Inc.\r
3 //\r
4 // This library is free software; you can redistribute it and/or\r
5 // modify it under the terms of the GNU Lesser General Public\r
6 // License as published by the Free Software Foundation; either\r
7 // version 2.1 of the License, or (at your option) any later version.\r
8 // \r
9 // This library is distributed in the hope that it will be useful,\r
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of\r
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
12 // Lesser General Public License for more details.\r
13 // \r
14 // You should have received a copy of the GNU Lesser General Public\r
15 // License along with this library; if not, write to the Free Software\r
16 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA\r
17 \r
18 using System;\r
19 using System.Net;\r
20 using System.Net.Sockets;\r
21 using System.IO;\r
22 using ICSharpCode.SharpZipLib.Zip.Compression;\r
23 using ICSharpCode.SharpZipLib.Zip.Compression.Streams;\r
24 using System.Security.Cryptography;\r
25 using ByteFX.Data.Common;\r
26 using System.Collections;\r
27 using System.Text;\r
28 \r
namespace ByteFX.Data.MySqlClient
{
    /// <summary>
    /// Low-level MySQL wire-protocol driver. It owns the connection stream, performs the
    /// server handshake and authentication, and reads and writes protocol packets.
    /// </summary>
    internal class Driver
    {
        /// <summary>Packet header size: 3-byte payload length plus 1-byte sequence number.</summary>
        protected const int HEADER_LEN = 4;
        /// <summary>Payloads shorter than this are never compressed (see <see cref="CompressBuffer"/>).</summary>
        protected const int MIN_COMPRESS_LENGTH = 50;
        /// <summary>Largest payload one packet can carry: the maximum of the 3-byte length field.</summary>
        protected const int MAX_PACKET_SIZE = 256*256*256-1;

        protected Stream            stream;
        protected BufferedStream    writer;
        protected Encoding          encoding;
        protected byte              packetSeq;
        protected long              maxPacketSize;
        protected DBVersion         serverVersion;
        protected bool              isOpen;
        protected string            versionString;
        protected Packet            peekedPacket;

        protected int               protocol;
        protected uint              threadID;
        protected String            encryptionSeed;
        protected int               serverCaps;
        protected bool              useCompression = false;


        /// <summary>
        /// Creates a closed driver that uses the system default encoding.
        /// </summary>
        public Driver()
        {
            packetSeq = 0;
            encoding = System.Text.Encoding.Default;
            isOpen = false;
        }

        /// <summary>Encoding used to convert strings to and from packet bytes.</summary>
        public Encoding Encoding
        {
            get { return encoding; }
            set { encoding = value; }
        }

        /// <summary>
        /// Maximum packet size the server accepts. A value of 0 or less disables the
        /// size check in <see cref="SendPacket"/>.
        /// </summary>
        public long MaxPacketSize
        {
            get { return maxPacketSize; }
            set { maxPacketSize = value; }
        }

        /// <summary>Raw server version string taken from the greeting packet.</summary>
        public string VersionString
        {
            get { return versionString; }
        }

        /// <summary>Parsed server version taken from the greeting packet.</summary>
        public DBVersion Version
        {
            get { return serverVersion; }
        }

        /// <summary>
        /// Connects to one of the hosts in <paramref name="settings"/>, reads the server
        /// greeting and authenticates. If compression was negotiated, the connection is
        /// then switched to the compressed protocol.
        /// </summary>
        /// <param name="settings">Connection settings: hosts, port, pipe, credentials and options.</param>
        /// <exception cref="MySqlException">
        /// No host could be reached, or the handshake or authentication failed.
        /// </exception>
        public void Open( MySqlConnectionString settings )
        {
            // Start every connection from a clean protocol state. The greeting must arrive
            // with sequence number 0, even if this driver instance was used before.
            packetSeq = 0;
            useCompression = false;
            peekedPacket = null;

            // connect to one of our specified hosts
            try
            {
                StreamCreator sc = new StreamCreator( settings.Server, settings.Port, settings.PipeName );
                stream = sc.GetStream( settings.ConnectionTimeout );
            }
            catch (Exception ex)
            {
                throw new MySqlException("Unable to connect to any of the specified MySQL hosts", ex);
            }

            if (stream == null)
                throw new MySqlException("Unable to connect to any of the specified MySQL hosts");

            try
            {
                writer = new BufferedStream( stream );

                // read off the welcome packet and parse out its values
                Packet packet = ReadPacket();
                protocol = packet.ReadByte();
                versionString = packet.ReadString();
                serverVersion = DBVersion.Parse( versionString );
                threadID = (uint)packet.ReadInteger(4);
                encryptionSeed = packet.ReadString();

                // read in Server capabilities if they are provided
                serverCaps = 0;
                if (packet.HasMoreData)
                    serverCaps = (int)packet.ReadInteger(2);

                Authenticate( settings.UserId, settings.Password, settings.UseCompression );

                // Switch to the compressed stream only if compression was actually negotiated.
                // Authenticate sets useCompression only when the caller asked for it AND the
                // server advertises CLIENT_COMPRESS. Keying off settings.UseCompression alone
                // would garble the protocol against servers that lack compression support.
                // NOTE(review): SendPacket also frames data via SendCompressedBuffer when
                // useCompression is set; confirm CompressedStream does not compress outgoing
                // data a second time.
                if (useCompression)
                {
                    stream = new CompressedStream( stream );
                    writer = new BufferedStream( stream );
                }
            }
            catch
            {
                // don't leak the socket/pipe when the handshake fails part-way through
                Close();
                throw;
            }

            isOpen = true;
        }

        /// <summary>
        /// Creates a packet whose format matches the connected server version. The
        /// version check is whether the server is at least 3.22.5.
        /// </summary>
        /// <param name="buf">Received payload to wrap, or null for a new empty outgoing packet.</param>
        private Packet CreatePacket( byte[] buf )
        {
            bool newFormat = serverVersion.isAtLeast(3, 22, 5);
            return buf == null ? new Packet( newFormat ) : new Packet( buf, newFormat );
        }

        /// <summary>
        /// Sends the client authentication packet and reads the server's reply.
        /// </summary>
        /// <param name="userid">User name to log in as.</param>
        /// <param name="password">Plain-text password. It is scrambled before being sent.</param>
        /// <param name="UseCompression">Whether the caller wants the compressed protocol.</param>
        /// <exception cref="MySqlException">The server rejected the login.</exception>
        private void Authenticate( String userid, String password, bool UseCompression )
        {
            ClientParam clientParam = ClientParam.CLIENT_FOUND_ROWS | ClientParam.CLIENT_LONG_FLAG;

            // request compression only when wanted AND supported by the server
            if ((serverCaps & (int)ClientParam.CLIENT_COMPRESS) != 0 && UseCompression)
            {
                clientParam |= ClientParam.CLIENT_COMPRESS;
            }

            clientParam |= ClientParam.CLIENT_LONG_PASSWORD;
            clientParam |= ClientParam.CLIENT_LOCAL_FILES;
            // 4.1 protocol / secure-connection negotiation is currently disabled:
            //          if (serverVersion.isAtLeast(4,1,0))
            //                  clientParam |= ClientParam.CLIENT_PROTOCOL_41;
            //          if ( (serverCaps & (int)ClientParam.CLIENT_SECURE_CONNECTION ) != 0 && password.Length > 0 )
            //                  clientParam |= ClientParam.CLIENT_SECURE_CONNECTION;

            int packetLength = userid.Length + 16 + 6 + 4;  // Passwords can be 16 chars long

            Packet packet = CreatePacket(null);

            if ((clientParam & ClientParam.CLIENT_PROTOCOL_41) != 0)
            {
                packet.WriteInteger( (int)clientParam, 4 );
                packet.WriteInteger( (256*256*256)-1, 4 );
            }
            else
            {
                packet.WriteInteger( (int)clientParam, 2 );
                packet.WriteInteger( 255*255*255, 3 );
            }

            packet.WriteString( userid, encoding );
            if ( (clientParam & ClientParam.CLIENT_SECURE_CONNECTION ) != 0 )
            {
                // use the new authentication system
                AuthenticateSecurely( packet, password );
            }
            else
            {
                // use old authentication system
                packet.WriteString( EncryptPassword(password, encryptionSeed, protocol > 9), encoding );
                // Pad zeros out to packetLength for auth. packet.Length grows with every
                // byte written, so it is re-checked on each pass. A counter compared against
                // (packetLength - packet.Length) would stop after about half the gap.
                while (packet.Length < packetLength)
                    packet.WriteByte(0);
                SendPacket(packet);
            }

            // read the server's verdict; an error packet makes ReadPacket throw
            ReadPacket();
            if ((clientParam & ClientParam.CLIENT_COMPRESS) != 0)
                useCompression = true;
        }

        /// <summary>
        /// AuthenticateSecurely implements the new 4.1 authentication scheme.
        /// </summary>
        /// <param name="packet">The in-progress packet we use to complete the authentication</param>
        /// <param name="password">The password of the user to use</param>
        private void AuthenticateSecurely( Packet packet, string password )
        {
            packet.WriteString("xxxxxxxx", encoding );
            SendPacket(packet);

            packet = ReadPacket();

            // compute pass1 hash
            string newPass = password.Replace(" ","").Replace("\t","");
            using (SHA1 sha = new SHA1CryptoServiceProvider())
            {
                byte[] firstPassBytes = sha.ComputeHash( System.Text.Encoding.Default.GetBytes(newPass));

                byte[] salt = packet.GetBuffer();
                byte[] input = new byte[ firstPassBytes.Length + 4 ];
                // Only the 4-byte salt prefix belongs in 'input'. Copying the whole reply
                // buffer throws whenever that buffer is longer than 'input'.
                Array.Copy( salt, 0, input, 0, 4 );
                firstPassBytes.CopyTo( input, 4 );
                byte[] secondPassBytes = sha.ComputeHash( input );

                byte[] cryptSalt = new byte[20];
                Security.ArrayCrypt( salt, 4, cryptSalt, 0, secondPassBytes, 20 );

                Security.ArrayCrypt( cryptSalt, 0, firstPassBytes, 0, firstPassBytes, 20 );

                // send the packet
                packet = CreatePacket(null);
                packet.Write( firstPassBytes, 0, 20 );
                SendPacket(packet);
            }
        }


        /// <summary>
        /// Returns the next packet without consuming it. The following call to
        /// <see cref="ReadPacket"/> returns the same packet.
        /// </summary>
        /// <returns>The next packet from the server.</returns>
        public Packet PeekPacket()
        {
            if (peekedPacket == null)
                peekedPacket = ReadPacket();
            return peekedPacket;
        }

        /// <summary>
        /// ReadBuffer continuously loops until it has read the entire
        /// requested data.
        /// </summary>
        /// <param name="buf">Buffer to read data into</param>
        /// <param name="offset">Offset to place the data</param>
        /// <param name="length">Number of bytes to read</param>
        /// <exception cref="MySqlException">The stream ended before <paramref name="length"/> bytes arrived.</exception>
        private void ReadBuffer( byte[] buf, int offset, int length )
        {
            while (length > 0)
            {
                int amountRead = stream.Read( buf, offset, length );
                if (amountRead == 0)
                    throw new MySqlException("Unexpected end of data encountered");
                length -= amountRead;
                offset += amountRead;
            }
        }

        /// <summary>
        /// Reads one byte from the stream, turning end of stream (-1) into an exception
        /// so it cannot be mixed into a packet header as data.
        /// </summary>
        private int ReadByteOrThrow()
        {
            int b = stream.ReadByte();
            if (b == -1)
                throw new MySqlException("Unexpected end of data encountered");
            return b;
        }

        /// <summary>
        /// Reads one logical packet from the stream and checks its sequence number. On
        /// 4.0+ servers, a payload of exactly MAX_PACKET_SIZE bytes is joined with the
        /// packet(s) that follow it.
        /// </summary>
        private Packet ReadPacketFromServer()
        {
            // header: 3-byte little-endian length, then 1-byte sequence (evaluated left to right)
            int len = ReadByteOrThrow() + (ReadByteOrThrow() << 8) +
                (ReadByteOrThrow() << 16);
            byte seq = (byte)ReadByteOrThrow();
            byte[] buf = new byte[ len ];
            ReadBuffer( buf, 0, len );

            if (seq != packetSeq)
                throw new MySqlException("Unknown transmission status: sequence out of order");
            packetSeq++;

            Packet p = CreatePacket(buf);
            p.Encoding = this.Encoding;
            if (p.Length == MAX_PACKET_SIZE && serverVersion.isAtLeast(4,0,0))
                p.Append( ReadPacketFromServer() );
            return p;
        }

        /// <summary>
        /// Reads a single packet off the stream, or returns the peeked packet if there is one.
        /// </summary>
        /// <returns>The next packet from the server.</returns>
        /// <exception cref="MySqlException">The server sent an error packet (first byte 0xff).</exception>
        public Packet ReadPacket()
        {
            // if we have peeked at a packet, then return it
            if (peekedPacket != null)
            {
                Packet packet = peekedPacket;
                peekedPacket = null;
                return packet;
            }

            Packet p = ReadPacketFromServer();

            // if this is an error packet, then throw the exception (empty packets can't be errors)
            if (p.Length > 0 && p[0] == 0xff)
            {
                p.ReadByte();
                int errorCode = (int)p.ReadInteger(2);
                string msg = p.ReadString();
                throw new MySqlException( msg, errorCode );
            }

            return p;
        }

        /// <summary>
        /// Deflates a 4-byte header (3-byte length plus a zero byte) followed by the
        /// payload.
        /// </summary>
        /// <returns>
        /// The compressed bytes, or null if the payload is too short to compress or
        /// compression would not make it smaller.
        /// </returns>
        protected MemoryStream CompressBuffer(byte[] buf, int index, int length)
        {
            if (length < MIN_COMPRESS_LENGTH) return null;

            MemoryStream ms = new MemoryStream(buf.Length);
            DeflaterOutputStream dos = new DeflaterOutputStream(ms);

            dos.WriteByte( (byte)(length & 0xff ));
            dos.WriteByte( (byte)((length >> 8) & 0xff ));
            dos.WriteByte( (byte)((length >> 16) & 0xff ));
            dos.WriteByte( 0 );

            dos.Write( buf, index, length );
            dos.Finish();
            if (ms.Length > length+4) return null;
            return ms;
        }

        /// <summary>
        /// Writes the low <paramref name="numbytes"/> bytes of <paramref name="v"/> to the
        /// writer in little-endian order.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="numbytes"/> is not in 1..4.</exception>
        private void WriteInteger( int v, int numbytes )
        {
            if (numbytes < 1 || numbytes > 4)
                throw new ArgumentOutOfRangeException("numbytes", numbytes, "Wrong byte count for WriteInteger");

            int val = v;
            for (int x=0; x < numbytes; x++)
            {
                writer.WriteByte( (byte)(val&0xff) );
                val >>= 8;
            }
        }

        /// <summary>
        /// Send a buffer to the server in a compressed form.
        /// </summary>
        /// <param name="buf">Byte buffer to send</param>
        /// <param name="index">Location in buffer to start sending</param>
        /// <param name="length">Amount of data to send</param>
        protected void SendCompressedBuffer(byte[] buf, int index, int length)
        {
            MemoryStream compressed_bytes = CompressBuffer(buf, index, length);
            int comp_len = compressed_bytes == null ? length+HEADER_LEN : (int)compressed_bytes.Length;
            int ucomp_len = compressed_bytes == null ? 0 : length+HEADER_LEN;

            WriteInteger( comp_len, 3 );
            writer.WriteByte( packetSeq++ );
            WriteInteger( ucomp_len, 3 );
            if (compressed_bytes != null)
                writer.Write( compressed_bytes.GetBuffer(), 0, (int)compressed_bytes.Length );
            else
            {
                WriteInteger( length, 3 );
                writer.WriteByte( 0 );
                writer.Write( buf, index, length );
            }
            // Flush the buffered writer. Flushing only the underlying stream would leave
            // these bytes sitting in the writer's buffer.
            writer.Flush();
        }

        /// <summary>
        /// Writes a raw buffer to the server in chunks of at most 1024 bytes, flushing
        /// after each chunk.
        /// </summary>
        protected void SendBuffer( byte[] buf, int offset, int length )
        {
            while (length > 0)
            {
                int amount = Math.Min( 1024, length );
                writer.Write( buf, offset, amount );
                writer.Flush();
                offset += amount;
                length -= amount;
            }
        }

        /// <summary>
        /// Send a single packet to the server.
        /// </summary>
        /// <param name="packet">Packet to send to the server</param>
        /// <remarks>
        /// The packet is split into protocol packets of at most MAX_PACKET_SIZE bytes,
        /// and at least one packet is always sent. A full-size chunk tells the server
        /// that more data follows. A payload that ends exactly on a full-size chunk is
        /// therefore terminated with an empty packet, mirroring
        /// <see cref="ReadPacketFromServer"/>.
        /// </remarks>
        /// <exception cref="MySqlException">
        /// The packet exceeds <see cref="MaxPacketSize"/>, or writing to the connection failed.
        /// </exception>
        protected void SendPacket(Packet packet)
        {
            byte[]  buf = packet.GetBuffer();
            int     len = packet.Length;
            int     index = 0;

            // make sure we are not trying to send too much
            if (len > maxPacketSize && maxPacketSize > 0)
                throw new MySqlException("Packet size too large.  This MySQL server cannot accept rows larger than " + maxPacketSize + " bytes.");

            try
            {
                int lenToSend;
                do
                {
                    lenToSend = Math.Min( len, MAX_PACKET_SIZE );

                    // send the data
                    if (useCompression)
                        SendCompressedBuffer( buf, index, lenToSend );
                    else
                    {
                        WriteInteger( lenToSend, 3 );
                        writer.WriteByte( packetSeq++ );
                        writer.Write( buf, index, lenToSend );
                        writer.Flush();
                    }

                    len -= lenToSend;
                    index += lenToSend;
                }
                while (len > 0 || lenToSend == MAX_PACKET_SIZE);
                writer.Flush();
            }
            catch (Exception ex)
            {
                // Surface write failures instead of silently dropping them; otherwise the
                // caller goes on to wait for a reply to a request that was never sent.
                throw new MySqlException("Unable to send data to the MySQL server", ex);
            }
        }


        /// <summary>
        /// Closes the underlying connection stream, if any, and marks the driver as closed.
        /// </summary>
        public void Close()
        {
            if (stream != null)
                stream.Close();
            isOpen = false;
        }


        /// <summary>
        /// Sends the specified command to the database. The command must not produce a
        /// result set.
        /// </summary>
        /// <param name="command">Command to execute</param>
        /// <param name="text">Text attribute of command, encoded with <see cref="Encoding"/></param>
        /// <exception cref="MySqlException">The command produced a result set.</exception>
        public void Send( DBCmd command, String text )
        {
            CommandResult result = Send( command, this.Encoding.GetBytes( text ) );
            if (result.IsResultSet)
                throw new MySqlException("SendCommand failed for command " + text );
        }

        /// <summary>
        /// Sends a command packet to the server, starting a new sequence, and reads the
        /// first reply. If the server answers with a LOAD DATA LOCAL request, the named
        /// file is streamed to the server before the final reply is read.
        /// </summary>
        /// <param name="cmd">Command code to send.</param>
        /// <param name="bytes">Command arguments, or null for none.</param>
        /// <returns>A <see cref="CommandResult"/> wrapping the server's reply packet.</returns>
        public CommandResult Send( DBCmd cmd, byte[] bytes )
        {
            Packet packet = CreatePacket(null);
            packetSeq = 0;
            packet.WriteByte( (byte)cmd );
            if (bytes != null)
                packet.Write( bytes, 0, bytes.Length );

            SendPacket( packet );
            packet = ReadPacket();

            // first check to see if this is a LOAD DATA LOCAL callback
            // if so, send the file and then read the results
            long fieldcount = packet.ReadLenInteger();
            if (fieldcount == Packet.NULL_LEN)
            {
                string filename = packet.ReadString();
                SendFileToServer( filename );
                packet = ReadPacket();
            }
            else
                packet.Position = 0;

            return new CommandResult(packet, this);
        }

        /// <summary>
        /// Sends the specified file to the server.
        /// This supports the LOAD DATA LOCAL INFILE command.
        /// </summary>
        /// <param name="filename">Path of the file the server asked for.</param>
        /// <remarks>The terminating empty packet is always sent, even if reading the file fails.</remarks>
        /// <exception cref="MySqlException">The file could not be read or sent.</exception>
        private void SendFileToServer( string filename )
        {
            Packet      p = CreatePacket(null);
            byte[]      buffer = new byte[4092];
            FileStream  fs = null;
            try
            {
                // Request read access only. FileMode.Open on its own defaults to ReadWrite,
                // which fails for read-only files.
                fs = new FileStream( filename, FileMode.Open, FileAccess.Read );
                int count = fs.Read( buffer, 0, buffer.Length );
                while (count != 0)
                {
                    if ((p.Length + count) > MAX_PACKET_SIZE)
                    {
                        SendPacket( p );
                        p.Clear();
                    }
                    p.Write( buffer, 0, count );
                    count = fs.Read( buffer, 0, buffer.Length );
                }

                // send any remaining data
                if (p.Length > 0)
                {
                    SendPacket(p);
                    p.Clear();
                }
            }
            catch (Exception ex)
            {
                throw new MySqlException("Error during LOAD DATA LOCAL INFILE", ex);
            }
            finally
            {
                if (fs != null)
                    fs.Close();
                // empty packet signals end of file
                p.Clear();
                SendPacket(p);
            }
        }

        #region PasswordStuff
        /// <summary>
        /// One step of MySQL's legacy pseudo-random generator. It updates both seeds and
        /// returns a value in [0, 1).
        /// </summary>
        private static double rand(ref long seed1, ref long seed2)
        {
            seed1 = (seed1 * 3) + seed2;
            seed1 %= 0x3fffffff;
            seed2 = (seed1 + seed2 + 33) % 0x3fffffff;
            return (seed1 / (double)0x3fffffff);
        }

        /// <summary>
        /// Encrypts a password using the MySql encryption scheme.
        /// </summary>
        /// <param name="password">The password to encrypt</param>
        /// <param name="message">The encryption seed the server gave us</param>
        /// <param name="new_ver">Indicates if we should use the old or new encryption scheme</param>
        /// <returns>
        /// A scrambled string as long as <paramref name="message"/>. A null or empty
        /// password is returned unchanged.
        /// </returns>
        public static String EncryptPassword(String password, String message, bool new_ver)
        {
            if (password == null || password.Length == 0)
                return password;

            long[] hash_message = Hash(message);
            long[] hash_pass = Hash(password);

            long seed1 = (hash_message[0]^hash_pass[0]) % 0x3fffffff;
            long seed2 = (hash_message[1]^hash_pass[1]) % 0x3fffffff;

            char[] scrambled = new char[message.Length];
            for (int x=0; x < message.Length; x++)
            {
                double r = rand(ref seed1, ref seed2);
                scrambled[x] = (char)(Math.Floor(r*31) + 64);
            }

            if (new_ver)
            {                       /* Make it harder to break */
                char extra = (char)Math.Floor( rand(ref seed1, ref seed2) * 31 );
                for (int x=0; x < scrambled.Length; x++)
                    scrambled[x] ^= extra;
            }

            return new string(scrambled);
        }

        /// <summary>
        /// Legacy MySQL password hash. Spaces and tabs are skipped, and only the low
        /// 8 bits of each character are used.
        /// </summary>
        /// <param name="P">Text to hash.</param>
        /// <returns>Two 31-bit hash values.</returns>
        static long[] Hash(String P)
        {
            long val1 = 1345345333;
            long val2 = 0x12345671;
            long inc  = 7;

            for (int i=0; i < P.Length; i++)
            {
                if (P[i] == ' ' || P[i] == '\t') continue;
                long temp = (long)(0xff & P[i]);
                val1 ^= (((val1 & 63)+inc)*temp) + (val1 << 8);
                val2 += (val2 << 8) ^ val1;
                inc += temp;
            }

            long[] hash = new long[2];
            hash[0] = val1 & 0x7fffffff;
            hash[1] = val2 & 0x7fffffff;
            return hash;
        }
        #endregion
    }
}