New test.
[mono.git] / mcs / class / Mono.Data.SqliteClient / Mono.Data.SqliteClient / SqliteCommand.cs
index c0fc4ae776f0d7ff936af0835349af85f54f4055..e2dd694d2e12d71c78a4514864fa0df35bec9167 100644 (file)
@@ -9,6 +9,7 @@
 //             Chris Turchin <chris@turchin.net>
 //             Jeroen Zwartepoorte <jeroen@xs4all.nl>
 //             Thomas Zoechling <thomas.zoechling@gmx.at>
+//             Joshua Tauberer <tauberer@for.net>
 //
 // Copyright (C) 2002  Vladimir Vukicevic
 //
@@ -39,6 +40,7 @@ using System.Runtime.InteropServices;
 using System.Text.RegularExpressions;
 using System.Data;
 using System.Diagnostics; 
+using Group = System.Text.RegularExpressions.Group;
 
 namespace Mono.Data.SqliteClient 
 {
@@ -55,7 +57,6 @@ namespace Mono.Data.SqliteClient
                private UpdateRowSource upd_row_source;
                private SqliteParameterCollection sql_params;
                private bool prepared = false;
-               private ArrayList pStmts;
 
                #endregion
 
@@ -95,7 +96,7 @@ namespace Mono.Data.SqliteClient
                public string CommandText 
                {
                        get { return sql; }
-                       set { sql = value; }
+                       set { sql = value; prepared = false; }
                }
                
                public int CommandTimeout
@@ -170,27 +171,171 @@ namespace Mono.Data.SqliteClient
                                return Sqlite.sqlite_changes(parent_conn.Handle);
                }
                
-               private string ReplaceParams(Match m)
+               private void BindParameters3 (IntPtr pStmt)
                {
-                       string input = m.Value;                                                                                                                
-                       if (m.Groups["param"].Success)
+                       if (sql_params == null) return;
+                       if (sql_params.Count == 0) return;
+                       
+                       int pcount = Sqlite.sqlite3_bind_parameter_count (pStmt);
+
+                       for (int i = 1; i <= pcount; i++) 
                        {
-                               Group g = m.Groups["param"];
-                               string find = g.Value;
-                               //FIXME: sqlite works internally only with strings, so this assumtion is mostly legit, but what about date formatting, etc?
-                               //Need to fix SqlLiteDataReader first to acurately describe the tables
-                               SqliteParameter sqlp = Parameters[find];
-                               string replace = Convert.ToString(sqlp.Value);
-                               if(sqlp.DbType == DbType.String)
-                               {
-                                       replace =  "\"" + replace + "\"";
+                               String name = Sqlite.HeapToString (Sqlite.sqlite3_bind_parameter_name (pStmt, i), Encoding.UTF8);
+
+                               SqliteParameter param = null;
+                               if (name != null)
+                                       param = sql_params[name];
+                               else
+                                       param = sql_params[i-1];
+                               
+                               if (param.Value == null) {
+                                       Sqlite.sqlite3_bind_null (pStmt, i);
+                                       continue;
                                }
+                                       
+                               Type ptype = param.Value.GetType ();
+                               if (ptype.IsEnum)
+                                       ptype = Enum.GetUnderlyingType (ptype);
                                
-                               input = Regex.Replace(input,find,replace);
-                               return input;
+                               SqliteError err;
+                               
+                               if (ptype.Equals (typeof (String))) 
+                               {
+                                       String s = (String)param.Value;
+                                       err = Sqlite.sqlite3_bind_text16 (pStmt, i, s, -1, (IntPtr)(-1));
+                               } 
+                               else if (ptype.Equals (typeof (DBNull))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_null (pStmt, i);
+                               }
+                               else if (ptype.Equals (typeof (Boolean))) 
+                               {
+                                       bool b = (bool)param.Value;
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, b ? 1 : 0);
+                               } else if (ptype.Equals (typeof (Byte))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (Byte)param.Value);
+                               }
+                               else if (ptype.Equals (typeof (Char))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (Char)param.Value);
+                               } 
+                               else if (ptype.IsEnum) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (Int32)param.Value);
+                               }
+                               else if (ptype.Equals (typeof (Int16))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (Int16)param.Value);
+                               } 
+                               else if (ptype.Equals (typeof (Int32))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (Int32)param.Value);
+                               }
+                               else if (ptype.Equals (typeof (SByte))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (SByte)param.Value);
+                               } 
+                               else if (ptype.Equals (typeof (UInt16))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int (pStmt, i, (UInt16)param.Value);
+                               }
+                               else if (ptype.Equals (typeof (DateTime))) 
+                               {
+                                       DateTime dt = (DateTime)param.Value;
+                                       err = Sqlite.sqlite3_bind_int64 (pStmt, i, dt.ToFileTime ());
+                               } 
+                               else if (ptype.Equals (typeof (Double))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_double (pStmt, i, (Double)param.Value);
+                               }
+                               else if (ptype.Equals (typeof (Single))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_double (pStmt, i, (Single)param.Value);
+                               } 
+                               else if (ptype.Equals (typeof (UInt32))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int64 (pStmt, i, (UInt32)param.Value);
+                               }
+                               else if (ptype.Equals (typeof (Int64))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_int64 (pStmt, i, (Int64)param.Value);
+                               } 
+                               else if (ptype.Equals (typeof (Byte[]))) 
+                               {
+                                       err = Sqlite.sqlite3_bind_blob (pStmt, i, (Byte[])param.Value, ((Byte[])param.Value).Length, (IntPtr)(-1));
+                               } 
+                               else 
+                               {
+                                       throw new ApplicationException("Unkown Parameter Type");
+                               }
+                               if (err != SqliteError.OK) 
+                               {
+                                       throw new ApplicationException ("Sqlite error in bind " + err);
+                               }
+                       }
+               }
+
+               private void GetNextStatement (IntPtr pzStart, out IntPtr pzTail, out IntPtr pStmt)
+               {
+                       if (parent_conn.Version == 3)
+                       {
+                               SqliteError err = Sqlite.sqlite3_prepare16 (parent_conn.Handle, pzStart, -1, out pStmt, out pzTail);
+                               if (err != SqliteError.OK)
+                                       throw new SqliteSyntaxException (GetError3());
                        }
                        else
-                       return m.Value;
+                       {
+                               IntPtr errMsg;
+                               SqliteError err = Sqlite.sqlite_compile (parent_conn.Handle, pzStart, out pzTail, out pStmt, out errMsg);
+                               
+                               if (err != SqliteError.OK) 
+                               {
+                                       string msg = "unknown error";
+                                       if (errMsg != IntPtr.Zero) 
+                                       {
+                                               msg = Marshal.PtrToStringAnsi (errMsg);
+                                               Sqlite.sqliteFree (errMsg);
+                                       }
+                                       throw new SqliteSyntaxException (msg);
+                               }
+                       }
+               }
+               
+               // Executes a statement and ignores its result.
+               private void ExecuteStatement (IntPtr pStmt) {
+                       int cols;
+                       IntPtr pazValue, pazColName;
+                       ExecuteStatement (pStmt, out cols, out pazValue, out pazColName);
+               }
+
+               // Executes a statement and returns whether there is more data available.
+               internal bool ExecuteStatement (IntPtr pStmt, out int cols, out IntPtr pazValue, out IntPtr pazColName) {
+                       SqliteError err;
+                       
+                       if (parent_conn.Version == 3) 
+                       {
+                               err = Sqlite.sqlite3_step (pStmt);
+                               if (err == SqliteError.ERROR)
+                                       throw new SqliteExecutionException (GetError3());
+                               pazValue = IntPtr.Zero; pazColName = IntPtr.Zero; // not used for v=3
+                               cols = Sqlite.sqlite3_column_count (pStmt);
+                       }
+                       else 
+                       {
+                               err = Sqlite.sqlite_step (pStmt, out cols, out pazValue, out pazColName);
+                               if (err == SqliteError.ERROR)
+                                       throw new SqliteExecutionException ();
+                       }
+                       
+                       if (err == SqliteError.BUSY)
+                               throw new SqliteBusyException();
+                       
+                       if (err == SqliteError.MISUSE)
+                               throw new SqliteExecutionException();
+                               
+                       // err is either ROW or DONE.
+                       return err == SqliteError.ROW;
                }
                
                #endregion
@@ -201,154 +346,61 @@ namespace Mono.Data.SqliteClient
                {
                }
                
-               public string ProcessParameters()
+               public string BindParameters2()
                {
-                       string processedText = sql;
-
-                       //Regex looks odd perhaps, but it works - same impl. as in the firebird db provider
-                       //the named parameters are using the ADO.NET standard @-prefix but sqlite is considering ":" as a prefix for v.3...
-                       //ref: http://www.mail-archive.com/sqlite-users@sqlite.org/msg01851.html
-                       //Regex r = new Regex(@"(('[^']*?\@[^']*')*[^'@]*?)*(?<param>@\w+)+([^'@]*?('[^']*?\@[^']*'))*",RegexOptions.ExplicitCapture);
+                       string text = sql;
                        
-                       //The above statement is true for the commented regEx, but I changed it to use the :-prefix, because now (12.05.2005 sqlite3) 
-                       //sqlite is using : as Standard Parameterprefix
+                       // There used to be a crazy regular expression here, but it caused Mono
+                       // to go into an infinite loop of some sort when there were no parameters
+                       // in the SQL string.  That was too complicated anyway.
                        
-                       Regex r = new Regex(@"(('[^']*?\:[^']*')*[^':]*?)*(?<param>:\w+)+([^':]*?('[^']*?\:[^']*'))*",RegexOptions.ExplicitCapture);
-                       MatchEvaluator me = new MatchEvaluator(ReplaceParams);
-                       processedText = r.Replace(sql, me);
-                       return processedText;
+                       // Here we search for substrings of the form [:?]wwwww where w is a letter or digit
+                       // (not sure what a legitimate Sqlite3 identifier is), except those within quotes.
+                       
+                       char inquote = (char)0;
+                       int counter = 0;
+                       for (int i = 0; i < text.Length; i++) {
+                               char c = text[i];
+                               if (c == inquote) {
+                                       inquote = (char)0;
+                               } else if (inquote == (char)0 && (c == '\'' || c == '"')) {
+                                       inquote = c;
+                               } else if (inquote == (char)0 && (c == ':' || c == '?')) {
+                                       int start = i;
+                                       while (++i < text.Length && char.IsLetterOrDigit(text[i])) { } // scan to end
+                                       string name = text.Substring(start, i-start);
+                                       SqliteParameter p;
+                                       if (name.Length > 1)
+                                               p = Parameters[name];
+                                       else
+                                               p = Parameters[counter];
+                                       string value = "'" + Convert.ToString(p.Value).Replace("'", "''") + "'";
+                                       text = text.Remove(start, name.Length).Insert(start, value);
+                                       i += value.Length - name.Length - 1;
+                                       counter++;
+                               }
+                       }
+                       
+                       return text;
                }
                
                public void Prepare ()
                {
-                       pStmts = new ArrayList();
-                       string sqlcmds = sql;
+                       // There isn't much we can do here.  If a table schema
+                       // changes after preparing a statement, Sqlite bails,
+                       // so we can only compile statements right before we
+                       // want to run them.
                        
+                       if (prepared) return;           
+               
                        if (Parameters.Count > 0 && parent_conn.Version == 2)
                        {
-                               sqlcmds = ProcessParameters();
+                               sql = BindParameters2();
                        }
                        
-                       SqliteError err = SqliteError.OK;
-                       IntPtr psql = Marshal.StringToCoTaskMemAnsi(sqlcmds);
-                       IntPtr pzTail = psql;
-                       try {
-                               do { // sql may contain multiple sql commands, loop until they're all processed
-                                       IntPtr pStmt = IntPtr.Zero;
-                                       if (parent_conn.Version == 3)
-                                       {
-                                               err = Sqlite.sqlite3_prepare (parent_conn.Handle, pzTail, sql.Length, out pStmt, out pzTail);
-                                               if (err != SqliteError.OK) {
-                                                       string msg = Marshal.PtrToStringAnsi (Sqlite.sqlite3_errmsg (parent_conn.Handle));
-                                                       throw new ApplicationException (msg);
-                                               }
-                                       }
-                                       else
-                                       {
-                                               IntPtr errMsg;
-                                               err = Sqlite.sqlite_compile (parent_conn.Handle, pzTail, out pzTail, out pStmt, out errMsg);
-                                               
-                                               if (err != SqliteError.OK) 
-                                               {
-                                                       string msg = "unknown error";
-                                                       if (errMsg != IntPtr.Zero) 
-                                                       {
-                                                               msg = Marshal.PtrToStringAnsi (errMsg);
-                                                               Sqlite.sqliteFree (errMsg);
-                                                       }
-                                                       throw new ApplicationException ("Sqlite error: " + msg);
-                                               }
-                                       }
-                                               
-                                       pStmts.Add(pStmt);
-                                       
-                                       if (parent_conn.Version == 3) 
-                                       {
-                                               int pcount = Sqlite.sqlite3_bind_parameter_count (pStmt);
-                                               if (sql_params == null) pcount = 0;
-               
-                                               for (int i = 1; i <= pcount; i++) 
-                                               {
-                                                       String name = Sqlite.sqlite3_bind_parameter_name (pStmt, i);
-                                                       SqliteParameter param = sql_params[name];
-                                                       Type ptype = param.Value.GetType ();
-                                                       
-                                                       if (ptype.Equals (typeof (String))) 
-                                                       {
-                                                               String s = (String)param.Value;
-                                                               err = Sqlite.sqlite3_bind_text (pStmt, i, s, s.Length, (IntPtr)(-1));
-                                                       } 
-                                                       else if (ptype.Equals (typeof (DBNull))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_null (pStmt, i);
-                                                       }
-                                                       else if (ptype.Equals (typeof (Boolean))) 
-                                                       {
-                                                               bool b = (bool)param.Value;
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, b ? 1 : 0);
-                                                       } else if (ptype.Equals (typeof (Byte))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, (Byte)param.Value);
-                                                       }
-                                                       else if (ptype.Equals (typeof (Char))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, (Char)param.Value);
-                                                       } 
-                                                       else if (ptype.Equals (typeof (Int16))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, (Int16)param.Value);
-                                                       } 
-                                                       else if (ptype.Equals (typeof (Int32))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, (Int32)param.Value);
-                                                       }
-                                                       else if (ptype.Equals (typeof (SByte))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, (SByte)param.Value);
-                                                       } 
-                                                       else if (ptype.Equals (typeof (UInt16))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int (pStmt, i, (UInt16)param.Value);
-                                                       }
-                                                       else if (ptype.Equals (typeof (DateTime))) 
-                                                       {
-                                                               DateTime dt = (DateTime)param.Value;
-                                                               err = Sqlite.sqlite3_bind_int64 (pStmt, i, dt.ToFileTime ());
-                                                       } 
-                                                       else if (ptype.Equals (typeof (Double))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_double (pStmt, i, (Double)param.Value);
-                                                       }
-                                                       else if (ptype.Equals (typeof (Single))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_double (pStmt, i, (Single)param.Value);
-                                                       } 
-                                                       else if (ptype.Equals (typeof (UInt32))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int64 (pStmt, i, (UInt32)param.Value);
-                                                       }
-                                                       else if (ptype.Equals (typeof (Int64))) 
-                                                       {
-                                                               err = Sqlite.sqlite3_bind_int64 (pStmt, i, (Int64)param.Value);
-                                                       } 
-                                                       else 
-                                                       {
-                                                               throw new ApplicationException("Unkown Parameter Type");
-                                                       }
-                                                       if (err != SqliteError.OK) 
-                                                       {
-                                                               throw new ApplicationException ("Sqlite error in bind " + err);
-                                                       }
-                                               }
-                                       }
-                               } while ((int)pzTail - (int)psql < sql.Length);
-                       } finally {
-                               Marshal.FreeCoTaskMem(psql);
-                       }
-                       prepared=true;
+                       prepared = true;
                }
                
-               
                IDbDataParameter IDbCommand.CreateParameter()
                {
                        return CreateParameter ();
@@ -362,7 +414,7 @@ namespace Mono.Data.SqliteClient
                public int ExecuteNonQuery ()
                {
                        int rows_affected;
-                       SqliteDataReader r = ExecuteReader (CommandBehavior.Default, false, out rows_affected);
+                       ExecuteReader (CommandBehavior.Default, false, out rows_affected);
                        return rows_affected;
                }
                
@@ -400,81 +452,92 @@ namespace Mono.Data.SqliteClient
                
                public SqliteDataReader ExecuteReader (CommandBehavior behavior, bool want_results, out int rows_affected)
                {
-                       SqliteDataReader reader = null;
-                       SqliteError err = SqliteError.OK;
-                       IntPtr errMsg = IntPtr.Zero; 
+                       Prepare ();
+                       
+                       // The SQL string may contain multiple sql commands, so the main
+                       // thing to do is have Sqlite iterate through the commands.
+                       // If want_results, only the last command is returned as a
+                       // DataReader.  Otherwise, no command is returned as a
+                       // DataReader.
+               
+                       IntPtr psql; // pointer to SQL command
+                       
+                       // Sqlite 2 docs say this: By default, SQLite assumes that all data uses a fixed-size 8-bit 
+                       // character (iso8859).  But if you give the --enable-utf8 option to the configure script, then the 
+                       // library assumes UTF-8 variable sized characters. This makes a difference for the LIKE and GLOB 
+                       // operators and the LENGTH() and SUBSTR() functions. The static string sqlite_encoding will be set 
+                       // to either "UTF-8" or "iso8859" to indicate how the library was compiled. In addition, the sqlite.h 
+                       // header file will define one of the macros SQLITE_UTF8 or SQLITE_ISO8859, as appropriate.
+                       // 
+                       // We have no way of knowing whether Sqlite 2 expects ISO8859 or UTF-8, but ISO8859 seems to be the
+                       // default.  Therefore, we need to use an ISO8859(-1) compatible encoding, like ANSI.
+                       // OTOH, the user may want to specify the encoding of the bytes stored in the database, regardless
+                       // of what Sqlite is treating them as, 
+                       
+                       // For Sqlite 3, we use the UTF-16 prepare function, so we need a UTF-16 string.
+                       
+                       if (parent_conn.Version == 2)
+                               psql = Sqlite.StringToHeap (sql.Trim(), parent_conn.Encoding);
+                       else
+                               psql = Marshal.StringToHGlobalUni (sql.Trim());
+
+                       IntPtr pzTail = psql;
+                       IntPtr errMsgPtr;
+                       
                        parent_conn.StartExec ();
-                 
-                       try 
-                       {
-                               if (!prepared)
-                               {
-                                       Prepare ();
-                               }
-                               for (int i = 0; i < pStmts.Count; i++) {
-                                       IntPtr pStmt = (IntPtr)pStmts[i];
+                       
+                       rows_affected = 0;
+                       
+                       try {
+                               while (true) {
+                                       IntPtr pStmt;
+                                        
+                                       GetNextStatement(pzTail, out pzTail, out pStmt);
                                        
-                                       // If want_results, return the results of the last statement
-                                       // via the SqliteDataReader, and execute but ignore the results
-                                       // of the other statements.
-                                       if (i == pStmts.Count-1 && want_results) 
-                                       {
-                                               reader = new SqliteDataReader (this, pStmt, parent_conn.Version);
-                                               break;
-                                       } 
+                                       if (pStmt == IntPtr.Zero)
+                                               throw new Exception();
                                        
-                                       // Execute but ignore the results of these statements.
-                                       if (parent_conn.Version == 3) 
-                                       {
-                                               err = Sqlite.sqlite3_step (pStmt);
-                                       }
-                                       else 
-                                       {
-                                               int cols;
-                                               IntPtr pazValue = IntPtr.Zero;
-                                               IntPtr pazColName = IntPtr.Zero;
-                                               err = Sqlite.sqlite_step (pStmt, out cols, out pazValue, out pazColName);
-                                       }
-                                       // On error, misuse, or busy, don't bother with the rest of the statements.
-                                       if (err != SqliteError.ROW && err != SqliteError.DONE) break;
-                               }
-                       }
-                       finally 
-                       {       
-                               foreach (IntPtr pStmt in pStmts) {
-                                       if (parent_conn.Version == 3) 
-                                       {
-                                               err = Sqlite.sqlite3_finalize (pStmt);
-                                       }
-                                       else
-                                       {
-                                               err = Sqlite.sqlite_finalize (pStmt, out errMsg);
+                                       // pzTail is positioned after the last byte in the
+                                       // statement, which will be the NULL character if
+                                       // this was the last statement.
+                                       bool last = Marshal.ReadByte(pzTail) == 0;
+
+                                       try {
+                                               if (parent_conn.Version == 3)
+                                                       BindParameters3 (pStmt);
+                                               
+                                               if (last && want_results)
+                                                       return new SqliteDataReader (this, pStmt, parent_conn.Version);
+
+                                               ExecuteStatement(pStmt);
+                                               
+                                               if (last) // rows_affected is only used if !want_results
+                                                       rows_affected = NumChanges ();
+                                               
+                                       } finally {
+                                               if (parent_conn.Version == 3) 
+                                                       Sqlite.sqlite3_finalize (pStmt);
+                                               else
+                                                       Sqlite.sqlite_finalize (pStmt, out errMsgPtr);
                                        }
+                                       
+                                       if (last) break;
                                }
+
+                               return null;
+                       } finally {
                                parent_conn.EndExec ();
-                               prepared = false;
+                               Marshal.FreeHGlobal (psql);
                        }
-                       
-                       if (err != SqliteError.OK &&
-                           err != SqliteError.DONE &&
-                           err != SqliteError.ROW) 
-                       {
-                               if (errMsg != IntPtr.Zero) 
-                               {
-                                       // TODO: Get the message text
-                               }
-                               throw new ApplicationException ("Sqlite error");
-                       }
-                       rows_affected = NumChanges ();
-                       return reader;
                }
-               
+
                public int LastInsertRowID () 
                {
-                       if (parent_conn.Version == 3)
-                               return Sqlite.sqlite3_last_insert_rowid(parent_conn.Handle);
-                       else
-                               return Sqlite.sqlite_last_insert_rowid(parent_conn.Handle);
+                       return parent_conn.LastInsertRowId;
+               }
+               
+               private string GetError3() {
+                       return Marshal.PtrToStringUni (Sqlite.sqlite3_errmsg16 (parent_conn.Handle));
                }
        #endregion
        }