Resyncing ByteFX.Data with latest from ByteFX
[mono.git] / mcs / class / ByteFX.Data / mysqlclient / Driver.cs
1 // ByteFX.Data data access components for .Net
2 // Copyright (C) 2002-2003  ByteFX, Inc.
3 //
4 // This library is free software; you can redistribute it and/or
5 // modify it under the terms of the GNU Lesser General Public
6 // License as published by the Free Software Foundation; either
7 // version 2.1 of the License, or (at your option) any later version.
8 // 
9 // This library is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 // Lesser General Public License for more details.
13 // 
14 // You should have received a copy of the GNU Lesser General Public
15 // License along with this library; if not, write to the Free Software
16 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
18 using System;
19 using System.Net;
20 using System.Net.Sockets;
21 using System.IO;
22 using ICSharpCode.SharpZipLib.Zip.Compression;
23 using System.Security.Cryptography;
24 using ByteFX.Data.Common;
25 using System.Text;
26
27 namespace ByteFX.Data.MySqlClient
28 {
29         /// <summary>
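/// Client side of the MySQL wire protocol: server greeting and login, packet framing (optionally compressed), and command/query round trips.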
31         /// </summary>
32         internal class Driver
33         {
// Every uncompressed packet header is 3 length bytes + 1 sequence byte;
// CompressPacket leaves packets shorter than MIN_COMPRESS_LENGTH uncompressed.
protected const int HEADER_LEN = 4;
protected const int MIN_COMPRESS_LENGTH = 50;

// Connection state (protected so derived drivers can reach it).
protected MySqlStream stream;                       // transport; created by Open()
protected Encoding encoding;                        // text encoding for packet strings
protected byte packetSeq;                           // sequence number of the next packet sent/expected
protected int timeOut;                              // connect timeout handed to MySqlStream
protected long maxPacketSize;                       // exposed through MaxPacketSize; not read by this class
protected Packet peekedPacket;                      // packet held back by PeekPacket(), or null
protected ByteFX.Data.Common.Version serverVersion; // parsed from the server greeting
protected bool isOpen;                              // set true at the end of Open()

// Handshake values captured by Open() and consumed by Authenticate().
private int protocol;                               // protocol version byte from the greeting
private uint threadID;                              // server thread id from the greeting
private string encryptionSeed;                      // scramble seed for the legacy password scheme
private int serverCaps;                             // server capability flags (0 if the greeting had none)
private bool useCompression;                        // turned on once a compressed session is negotiated

/// <summary>
/// Creates a driver that is not yet connected. Packets use the system
/// default encoding until the Encoding property is changed.
/// </summary>
public Driver()
{
    isOpen = false;
    encoding = System.Text.Encoding.Default;
    packetSeq = 0;
}
59
#region Properties
/// <summary>
/// True when the connection's stream is closed. A driver that was never
/// opened has no stream and also reports itself as dead.
/// </summary>
public bool IsDead
{
    get
    {
        // stream is only assigned in Open(); avoid a NullReferenceException
        // when this is queried before the driver was opened.
        if (stream == null)
            return true;
        return stream.IsClosed;
    }
}
#endregion

/// <summary>
/// Encoding used to convert strings to and from packet bytes.
/// </summary>
public Encoding Encoding 
{
    get { return encoding; }
    set { encoding = value; }
}

/// <summary>
/// Maximum packet size setting. It is only stored here; this class does not
/// consult it when sending or reading packets.
/// </summary>
public long MaxPacketSize 
{
    get { return maxPacketSize; }
    set { maxPacketSize = value; }
}
81
/// <summary>
/// Connects to the server, parses its greeting packet and authenticates.
/// </summary>
/// <param name="host">Server host, passed to MySqlStream.</param>
/// <param name="port">Server port, passed to MySqlStream.</param>
/// <param name="userid">User name to log in as.</param>
/// <param name="password">Password for <paramref name="userid"/>.</param>
/// <param name="UseCompression">Ask for the compressed protocol when the server offers it.</param>
/// <param name="connectTimeout">Timeout handed to MySqlStream.</param>
public void Open( String host, int port, String userid, String password, 
    bool UseCompression, int connectTimeout ) 
{
    timeOut = connectTimeout;
    stream = new MySqlStream( host, port, timeOut );

    // The first packet from the server is its greeting. Fields are read in
    // wire order: protocol version, server version string, thread id
    // (4 bytes), scramble seed, then optional 2-byte capability flags.
    Packet greeting = ReadPacket();
    protocol = greeting.ReadByte();
    serverVersion = ByteFX.Data.Common.Version.Parse( greeting.ReadString() );
    threadID = greeting.ReadInteger( 4 );
    encryptionSeed = greeting.ReadString();
    serverCaps = greeting.CanRead ? (int)greeting.ReadInteger( 2 ) : 0;

    Authenticate( userid, password, UseCompression );
    isOpen = true;
}
111
/// <summary>
/// Sends the client login packet and waits for the server's verdict.
/// Uses the secure scheme (AuthenticateSecurely) when the server advertises
/// CLIENT_SECURE_CONNECTION and the password is non-empty; otherwise sends
/// the legacy scrambled password produced by EncryptPassword.
/// </summary>
/// <param name="userid">User name to log in as.</param>
/// <param name="password">Password; null is treated as an empty password.</param>
/// <param name="UseCompression">Request CLIENT_COMPRESS when the server supports it.</param>
/// <exception cref="MySqlException">Raised by ReadPacket when the server answers with an error packet.</exception>
private void Authenticate( String userid, String password, bool UseCompression )
{
    // password.Length is used below; a null password would throw a
    // NullReferenceException, so treat it as "no password" instead.
    if (password == null)
        password = String.Empty;

    ClientParam clientParam = ClientParam.CLIENT_FOUND_ROWS | ClientParam.CLIENT_LONG_FLAG;

    // Only request compression if the server can do it.
    if ((serverCaps & (int)ClientParam.CLIENT_COMPRESS) != 0 && UseCompression)
        clientParam |= ClientParam.CLIENT_COMPRESS;

    clientParam |= ClientParam.CLIENT_LONG_PASSWORD;
    clientParam |= ClientParam.CLIENT_LOCAL_FILES;

    // CLIENT_PROTOCOL_41 is never requested yet, so the 4-byte branch below
    // is currently not taken.
    if ((serverCaps & (int)ClientParam.CLIENT_SECURE_CONNECTION) != 0 && password.Length > 0)
        clientParam |= ClientParam.CLIENT_SECURE_CONNECTION;

    Packet packet = new Packet();
    if ((clientParam & ClientParam.CLIENT_PROTOCOL_41) != 0)
    {
        packet.WriteInteger( (int)clientParam, 4 );
        packet.WriteInteger( (256*256*256)-1, 4 );
    }
    else
    {
        packet.WriteInteger( (int)clientParam, 2 );
        // NOTE(review): 255*255*255 (16581375) is not 0xFFFFFF; kept as-is
        // so the bytes sent do not change.
        packet.WriteInteger( 255*255*255, 3 );
    }

    packet.WriteString( userid, encoding );
    if ((clientParam & ClientParam.CLIENT_SECURE_CONNECTION) != 0)
    {
        // new authentication system
        AuthenticateSecurely( packet, password );
    }
    else
    {
        // old authentication system
        packet.WriteString( EncryptPassword(password, encryptionSeed, protocol > 9), encoding );
        SendPacket( packet );
    }

    // The reply's content is not needed: ReadPacket throws MySqlException
    // when the server rejects the login with an error packet.
    ReadPacket();

    if ((clientParam & ClientParam.CLIENT_COMPRESS) != 0)
        useCompression = true;
}
165
/// <summary>
/// AuthenticateSecurely implements the new 4.1 authentication scheme: a
/// placeholder password is sent, the server replies with a salt, and a
/// 20-byte proof derived from SHA1 of the password and that salt is sent back.
/// </summary>
/// <param name="packet">Login packet that already holds the flags, max packet size and user name.</param>
/// <param name="password">Password to prove knowledge of.</param>
private void AuthenticateSecurely( Packet packet, string password )
{
    packet.WriteString("xxxxxxxx", encoding );
    SendPacket(packet);

    packet = ReadPacket();
    // Salt layout as used below: bytes [0..3] feed the second hash and
    // bytes [4..23] feed ArrayCrypt.
    byte[] salt = packet.GetBytes();

    // pass1 = SHA1(password with spaces and tabs removed)
    string newPass = password.Replace(" ","").Replace("\t","");
    byte[] firstPassBytes;
    byte[] secondPassBytes;
    using (SHA1 sha = new SHA1CryptoServiceProvider())
    {
        firstPassBytes = sha.ComputeHash( System.Text.Encoding.Default.GetBytes(newPass) );

        // pass2 = SHA1(salt[0..3] + pass1). Copy exactly four salt bytes:
        // salt.CopyTo(input, 0) throws when the salt is longer than input.
        byte[] input = new byte[ firstPassBytes.Length + 4 ];
        Array.Copy( salt, 0, input, 0, 4 );
        firstPassBytes.CopyTo( input, 4 );
        secondPassBytes = sha.ComputeHash( input );
    }

    // Security.ArrayCrypt (defined elsewhere) derives cryptSalt from
    // salt[4..] and pass2, then applies it to pass1 in place.
    byte[] cryptSalt = new byte[20];
    Security.ArrayCrypt( salt, 4, cryptSalt, 0, secondPassBytes, 20 );
    Security.ArrayCrypt( cryptSalt, 0, firstPassBytes, 0, firstPassBytes, 20 );

    // send the 20-byte proof
    packet = new Packet();
    packet.WriteBytes( firstPassBytes, 0, 20 );
    SendPacket(packet);
}
199
200
/// <summary>
/// Reads one frame from the stream and wraps its payload in a Packet,
/// without interpreting error packets.
/// </summary>
/// <returns>The payload as a Packet whose Encoding is this driver's encoding.</returns>
/// <exception cref="MySqlException">The frame's sequence number is not the expected one.</exception>
/// <remarks>
/// Frame header: 3-byte payload length and a sequence byte. When compression
/// is on, it also carries a 3-byte uncompressed length; 0 means the payload
/// is stored uncompressed.
/// NOTE(review): SendPacket writes a 4-byte inner header after the
/// compression header, but this reader does not strip one from incoming
/// frames. Confirm whether server frames carry one and whether Packet
/// accounts for it.
/// </remarks>
private Packet ReadRawPacket()
{
    int frameLength = stream.ReadInt24();

    // read the packet sequence and make sure it makes sense
    byte sequence = (byte)stream.ReadByte();
    if (sequence != packetSeq) 
        throw new MySqlException("Unknown transmission status: sequence out of order");

    int inflatedLength = useCompression ? stream.ReadInt24() : 0;

    byte[] payload;
    if (!useCompression || inflatedLength <= 0)
    {
        payload = new byte[frameLength];
        // NOTE(review): the count returned by Read is ignored; assumes
        // MySqlStream.Read fills the buffer.
        stream.Read( payload, 0, frameLength );
    }
    else
    {
        byte[] deflated = new byte[frameLength];
        payload = new byte[inflatedLength];
        stream.Read( deflated, 0, frameLength );

        Inflater inflater = new Inflater();
        inflater.SetInput( deflated );
        inflater.Inflate( payload );
    }

    packetSeq++;
    Packet result = new Packet( payload );
    result.Encoding = encoding;
    return result;
}
242
/// <summary>
/// Answers the server's LOAD DATA LOCAL request (see SendSql).
/// </summary>
/// <remarks>
/// Sending local file contents is not implemented. The file transfer must
/// still be ended, or the server keeps waiting for data and the next command
/// is read as part of the transfer. This method sends an empty packet, which
/// ends the transfer with no data. It then reads the server's reply so the
/// connection is ready for the next command.
/// </remarks>
/// <exception cref="MySqlException">Raised by ReadPacket when the server answers with an error packet.</exception>
public void SendFileToServer()
{
    // An empty packet marks the end of the (empty) file data.
    SendPacket( new Packet() );

    // Consume the server's final reply to the LOAD DATA statement.
    ReadPacket();
}
249
250
/// <summary>
/// Returns the next packet without consuming it. Repeated peeks return the
/// same packet, and the next ReadPacket() call hands it out.
/// </summary>
/// <returns>The buffered packet, read from the server first if nothing is buffered.</returns>
public Packet PeekPacket()
{
    if (peekedPacket == null)
        peekedPacket = ReadPacket();
    return peekedPacket;
}
264
/// <summary>
/// Returns the next packet: the one buffered by PeekPacket() if present,
/// otherwise a newly read packet rewound to position 0.
/// </summary>
/// <returns>The next packet from the server.</returns>
/// <exception cref="MySqlException">The server sent an error packet; the exception carries its code and message.</exception>
public Packet ReadPacket()
{
    Packet buffered = peekedPacket;
    if (buffered != null)
    {
        peekedPacket = null;
        return buffered;
    }

    Packet packet = ReadRawPacket();
    if (packet.Type != PacketType.Error)
    {
        packet.Position = 0;
        return packet;
    }

    // Error packet: 2-byte error code followed by the message text.
    int errorCode = (int)packet.ReadInteger(2);
    string message = packet.ReadString();
    throw new MySqlException( message, errorCode );
}
292
/// <summary>
/// Appends <paramref name="count"/> column packets to <paramref name="packet"/>
/// and consumes the packet that must follow them.
/// </summary>
/// <param name="packet">Packet to extend; it is modified and returned.</param>
/// <param name="count">Number of column packets to read.</param>
/// <returns><paramref name="packet"/>, retyped as ResultSchema and rewound to position 0.</returns>
/// <exception cref="MySqlException">The packet after the columns is not of type Last.</exception>
private Packet LoadSchemaIntoPacket( Packet packet, int count )
{
    // Column packets come through ReadRawPacket, so error packets here are
    // not turned into exceptions.
    int remaining = count;
    while (remaining-- > 0)
        packet.AppendPacket( ReadRawPacket() );

    Packet terminator = ReadRawPacket();
    if (terminator.Type != PacketType.Last)
        throw new MySqlException("Last packet not received when expected");

    packet.Type = PacketType.ResultSchema;
    packet.Position = 0;
    return packet;
}
312
/// <summary>
/// Deflates the bytes of <paramref name="packet"/> for the compressed protocol.
/// </summary>
/// <param name="packet">Data to compress.</param>
/// <returns>
/// Exactly the compressed bytes. Returns null when the packet is shorter than
/// MIN_COMPRESS_LENGTH or compression would not make it smaller, in which
/// case the caller sends the data uncompressed.
/// </returns>
protected byte[] CompressPacket(Packet packet)
{
    int length = packet.Length;
    if (length < MIN_COMPRESS_LENGTH) return null;

    // Work buffer; deflate's worst-case growth is far below 2x the input.
    byte[] work = new byte[length * 2];
    Deflater deflater = new Deflater();
    deflater.SetInput( packet.GetBytes(), 0, length );
    deflater.Finish();
    int compressedLength = deflater.Deflate( work, 0, work.Length );

    // No gain from compressing: let the caller send the data as-is.
    if (compressedLength >= length) return null;

    // Return only the bytes deflate produced. The caller uses the array's
    // length as the on-wire payload length, so returning the whole work
    // buffer would send trailing garbage.
    byte[] compressed = new byte[compressedLength];
    Array.Copy( work, 0, compressed, 0, compressedLength );
    return compressed;
}
353
/// <summary>
/// Frames <paramref name="packet"/> with a header, writes it to the server
/// and advances the packet sequence number.
/// </summary>
/// <param name="packet">Payload to send.</param>
/// <remarks>
/// Without compression a frame is: 3-byte payload length, sequence byte,
/// payload.
/// With compression a frame is: 3-byte length of the data after the header,
/// sequence byte, 3-byte uncompressed length (0 when not compressed), then an
/// ordinary packet (3-byte length, inner sequence byte 0, payload). That
/// ordinary packet is deflated when CompressPacket finds it worthwhile.
/// </remarks>
protected void SendPacket(Packet packet)
{
    Packet header = new Packet();
    byte[] buffer;

    if (useCompression)
    {
        // Build the ordinary packet first. Its header must be part of the
        // data that gets deflated, so that the frame length field covers
        // everything written after the frame header.
        Packet inner = new Packet();
        inner.WriteInteger( packet.Length, 3 );
        inner.WriteByte( 0 );
        inner.WriteBytes( packet.GetBytes(), 0, packet.Length );

        byte[] compressed_bytes = CompressPacket(inner);
        if (compressed_bytes != null) 
        {
            header.WriteInteger( compressed_bytes.Length, 3 );
            header.WriteByte( packetSeq );
            header.WriteInteger( inner.Length, 3 );
            buffer = compressed_bytes;
        }
        else
        {
            // Not worth compressing: send the ordinary packet as-is and mark
            // the uncompressed length as 0.
            header.WriteInteger( inner.Length, 3 );
            header.WriteByte( packetSeq );
            header.WriteInteger( 0, 3 );
            buffer = inner.GetBytes();
        }
    }
    else 
    {
        header.WriteInteger( packet.Length, 3 );
        header.WriteByte( packetSeq );
        buffer = packet.GetBytes();
    }
    packetSeq++;

    // send the data to the server
    stream.Write( header.GetBytes(), 0, header.Length );
    stream.Write( buffer, 0, buffer.Length );
    stream.Flush();
}
397
398
/// <summary>
/// Closes the connection to the server and marks the driver as not open.
/// Safe to call on a driver that was never opened.
/// </summary>
public void Close() 
{
    // stream is only assigned in Open().
    if (stream != null)
        stream.Close();
    isOpen = false;
}
403
404
/// <summary>
/// Sends a command to the database and checks that the server accepted it.
/// </summary>
/// <param name="command">Command to execute.</param>
/// <param name="text">Argument text, written with WriteStringNoNull in the driver's encoding.</param>
/// <exception cref="MySqlException">The server replied with an error packet, or with a packet other than UpdateOrOk.</exception>
public void SendCommand( DBCmd command, String text ) 
{
    // Each command starts a fresh packet sequence.
    packetSeq = 0;

    Packet request = new Packet();
    request.WriteByte( (byte)command );
    request.WriteStringNoNull( text, encoding );
    SendPacket( request );

    Packet reply = ReadPacket();
    if (reply.Type != PacketType.UpdateOrOk)
        throw new MySqlException("SendCommand failed for command " + text );
}
423
/// <summary>
/// Sends already-encoded SQL to the server as a QUERY command.
/// </summary>
/// <param name="sql">SQL bytes, sent unchanged after the command byte.</param>
/// <returns>The first packet of the server's reply.</returns>
public Packet SendQuery( byte[] sql )
{
    // Each command starts a fresh packet sequence.
    packetSeq = 0;

    Packet query = new Packet();
    query.WriteByte( (byte)DBCmd.QUERY );
    query.WriteBytes( sql, 0, sql.Length );

    SendPacket( query );
    return ReadPacket();
}
439
/// <summary>
/// Encodes <paramref name="sql"/> with the driver's encoding, runs it as a
/// QUERY command and interprets the first reply packet.
/// </summary>
/// <param name="sql">SQL text to execute.</param>
/// <returns>
/// One of three results:
/// - null after a LoadDataLocal reply has been handed to SendFileToServer();
/// - for an Other reply whose leading length-encoded integer is positive,
///   the reply with that many column packets appended (LoadSchemaIntoPacket);
/// - otherwise the reply packet itself.
/// </returns>
public Packet SendSql( string sql )
{
    // SendQuery builds exactly the same QUERY request this method needs.
    Packet reply = SendQuery( encoding.GetBytes(sql) );

    PacketType kind = reply.Type;
    if (kind == PacketType.LoadDataLocal)
    {
        SendFileToServer();
        return null;
    }
    if (kind != PacketType.Other)
        return reply;

    reply.Position = 0;
    int columnCount = (int)reply.ReadLenInteger();
    if (columnCount > 0)
        return LoadSchemaIntoPacket( reply, columnCount );
    return reply;
}
469
#region PasswordStuff
/// <summary>
/// Pseudo-random step used by the legacy password scrambler. Advances both
/// seeds modulo 0x3fffffff and returns seed1 / 0x3fffffff, which lies in
/// [0, 1) for non-negative seeds.
/// </summary>
private static double rand(ref long seed1, ref long seed2)
{
    const long modulus = 0x3fffffff;
    seed1 = (seed1 * 3 + seed2) % modulus;
    seed2 = (seed1 + seed2 + 33) % modulus;
    return seed1 / (double)modulus;
}

/// <summary>
/// Scrambles <paramref name="password"/> with the server-supplied seed for
/// the legacy (non-secure-connection) login.
/// </summary>
/// <param name="password">Clear-text password; Hash skips its spaces and tabs.</param>
/// <param name="message">Seed string from the server greeting; the result has the same length.</param>
/// <param name="new_ver">When true, every output character is also XORed with one extra pseudo-random value.</param>
/// <returns>The scrambled string, or <paramref name="password"/> itself when it is null or empty.</returns>
public static String EncryptPassword(String password, String message, bool new_ver)
{
    if (password == null || password.Length == 0)
        return password;

    long[] messageHash = Hash(message);
    long[] passwordHash = Hash(password);
    long s1 = (messageHash[0] ^ passwordHash[0]) % 0x3fffffff;
    long s2 = (messageHash[1] ^ passwordHash[1]) % 0x3fffffff;

    // One character per seed character, each in '@' .. '^' (64 + 0..30).
    char[] output = new char[message.Length];
    int pos = 0;
    while (pos < output.Length)
    {
        output[pos] = (char)(Math.Floor(rand(ref s1, ref s2) * 31) + 64);
        pos++;
    }

    if (new_ver)
    {
        // Make it harder to break: one more draw, XORed into every character.
        char mask = (char)Math.Floor(rand(ref s1, ref s2) * 31);
        for (pos = 0; pos < output.Length; pos++)
            output[pos] = (char)(output[pos] ^ mask);
    }

    return new string(output);
}

/// <summary>
/// Two-accumulator hash used by EncryptPassword. Spaces and tabs are
/// skipped, and only the low 8 bits of each character are used.
/// </summary>
/// <param name="P">Text to hash.</param>
/// <returns>Two values, each masked to 31 bits.</returns>
static long[] Hash(String P) 
{
    long nr = 1345345333;
    long nr2 = 0x12345671;
    long add = 7;

    foreach (char c in P)
    {
        if (c == ' ' || c == '\t')
            continue;
        long tmp = (long)(0xff & c);
        nr ^= (((nr & 63) + add) * tmp) + (nr << 8);
        nr2 += (nr2 << 8) ^ nr;
        add += tmp;
    }

    return new long[] { nr & 0x7fffffff, nr2 & 0x7fffffff };
}
#endregion
539         }
540 }