Completing Go functions

Generate SQL scripts based on a CSV file using Go

You can query a model directly and test the results returned when using different parameter values with the Cloud console, or by calling the Vertex AI API directly.

System instructions

You are a Go developer who specializes in SQL databases.

Prompt Freeform

I have code that converts CSV to SQL. The functions to create the table and insert scripts are not completed. Can you finish these two functions for me?

convert.go

package main

import (
  "encoding/csv"
  "fmt"
  "os"
  "strconv"
  "strings"
  "time"
)

func inferDataType(value string, dialect string) string {
  value = strings.TrimSpace(value) // Trim leading and trailing spaces
  if _, err := strconv.Atoi(value); err == nil {
    return getIntegerType(dialect)
  }
  if _, err := strconv.ParseFloat(value, 64); err == nil {
    return getFloatType(dialect)
  }
  if _, err := time.Parse(time.RFC3339, value); err == nil {
    return getTimestampType(dialect)
  }
  if _, err := time.Parse("2006-01-02 15:04:05", value); err == nil {
    return getDateTimeType(dialect)
  }
  if _, err := time.Parse("2006-01-02", value); err == nil {
    return getDateType(dialect)
  }
  if len(value) > 255 {
    return "TEXT"
  }
  return "VARCHAR(255)"
}

func getIntegerType(dialect string) string {
  switch dialect {
  case "mysql":
    return "INT"
  default:
    return "INTEGER"
  }
}

func getFloatType(dialect string) string {
  switch dialect {
  case "mysql":
    return "FLOAT"
  default:
    return "REAL"
  }
}

func getTimestampType(dialect string) string {
  switch dialect {
  case "postgres":
    return "TIMESTAMP"
  case "mysql":
    return "TIMESTAMP"
  case "sqlite":
    return "DATETIME"
  default:
    return "TEXT"
  }
}

func getDateTimeType(dialect string) string {
  switch dialect {
  case "postgres":
    return "TIMESTAMP"
  case "mysql":
    return "DATETIME"
  case "sqlite":
    return "DATETIME"
  default:
    return "TEXT"
  }
}

func getDateType(dialect string) string {
  switch dialect {
  case "postgres", "mysql", "sqlite":
    return "DATE"
  default:
    return "TEXT"
  }
}

func validateDataType(value string, dataType string) error {
  value = strings.TrimSpace(value)
  switch dataType {
  case "INTEGER":
    if _, err := strconv.Atoi(value); err != nil {
      return err
    }
  case "REAL":
    if _, err := strconv.ParseFloat(value, 64); err != nil {
      return err
    }
  case "TIMESTAMP":
    if _, err := time.Parse(time.RFC3339, value); err != nil {
      return err
    }
  case "DATETIME":
    if _, err := time.Parse("2006-01-02 15:04:05", value); err != nil {
      return err
    }
  case "DATE":
    if _, err := time.Parse("2006-01-02", value); err != nil {
      return err
    }
  }
  return nil
}

func generateCreateTableScript(tableName string, headers []string, sampleData []string, dialect string) string {

}

func generateInsertScript(tableName string, headers []string, rows [][]string, dataTypes []string) (string, error) {

}

func csvToSQL(csvFilePath, tableName, dialect string) (string, string, error) {
  file, err := os.Open(csvFilePath)
  if err != nil {
    return "", "", err
  }
  defer file.Close()

  reader := csv.NewReader(file)
  headers, err := reader.Read()
  if err != nil {
    return "", "", err
  }

  // Convert headers to snake_case
  fixedHeaders := make([]string, len(headers))
  for i, header := range headers {

    fixedHeader := strings.ReplaceAll(strings.TrimSpace(header), " ", "_")
    fixedHeaders[i] = fixedHeader
  }

  sampleData, err := reader.Read()
  if err != nil {
    return "", "", err
  }

  rows, err := reader.ReadAll()
  if err != nil {
    return "", "", err
  }

  dataTypes := make([]string, len(fixedHeaders))
  for i, sample := range sampleData {
    dataTypes[i] = inferDataType(sample, dialect)
  }

  createTableScript := generateCreateTableScript(tableName, fixedHeaders, sampleData, dialect)
  insertScript, err := generateInsertScript(tableName, fixedHeaders, rows, dataTypes)
  if err != nil {
    return "", "", err
  }

  return createTableScript, insertScript, nil
}

main.go

package main

import (
  "flag"
  "fmt"
  "os"
  "path/filepath"
)

func main() {
  csvFilePath := flag.String("f", "", "Path to the CSV file")
  tableName := flag.String("t", "default_table", "Name of the table")
  dialect := flag.String("d", "postgres", "SQL dialect (postgres, mysql, sqlite)")

  flag.Parse()

  // Check if file path is provided
  if *csvFilePath == "" {
    fmt.Println("Error: CSV file path is required")
    flag.Usage()
    os.Exit(1)
  }

  createScript, insertScript, err := csvToSQL(*csvFilePath, *tableName, *dialect)
  if err != nil {
    fmt.Println("Error:", err)
    return
  }

  createFilePath := filepath.Join(".", fmt.Sprintf("%s_create.sql", *tableName))
  err = os.WriteFile(createFilePath, []byte(createScript), 0644)
  if err != nil {
    fmt.Println("Error writing CREATE TABLE script:", err)
    return
  }

  insertFilePath := filepath.Join(".", fmt.Sprintf("%s_insert.sql", *tableName))
  err = os.WriteFile(insertFilePath, []byte(insertScript), 0644)
  if err != nil {
    fmt.Println("Error writing INSERT scripts:", err)
    return
  }

  fmt.Printf("CREATE TABLE script written to: %s\n", createFilePath)
  fmt.Printf("INSERT scripts written to: %s\n", insertFilePath)
}

Response

package main

import (
    "encoding/csv"
    "fmt"
    "os"
    "strconv"
    "strings"
    "time"
)

func inferDataType(value string, dialect string) string {
    value = strings.TrimSpace(value) // Trim leading and trailing spaces
    if _, err := strconv.Atoi(value); err == nil {
        return getIntegerType(dialect)
    }
    if _, err := strconv.ParseFloat(value, 64); err == nil {
        return getFloatType(dialect)
    }
    if _, err := time.Parse(time.RFC3339, value); err == nil {
        return getTimestampType(dialect)
    }
    if _, err := time.Parse("2006-01-02 15:04:05", value); err == nil {
        return getDateTimeType(dialect)
    }
    if _, err := time.Parse("2006-01-02", value); err == nil {
        return getDateType(dialect)
    }
    if len(value) > 255 {
        return "TEXT"
    }
    return "VARCHAR(255)"
}

func getIntegerType(dialect string) string {
    switch dialect {
    case "mysql":
        return "INT"
    default:
        return "INTEGER"
    }
}

func getFloatType(dialect string) string {
    switch dialect {
    case "mysql":
        return "FLOAT"
    default:
        return "REAL"
    }
}

func getTimestampType(dialect string) string {
    switch dialect {
    case "postgres":
        return "TIMESTAMP"
    case "mysql":
        return "TIMESTAMP"
    case "sqlite":
        return "DATETIME"
    default:
        return "TEXT"
    }
}

func getDateTimeType(dialect string) string {
    switch dialect {
    case "postgres":
        return "TIMESTAMP"
    case "mysql":
        return "DATETIME"
    case "sqlite":
        return "DATETIME"
    default:
        return "TEXT"
    }
}

func getDateType(dialect string) string {
    switch dialect {
    case "postgres", "mysql", "sqlite":
        return "DATE"
    default:
        return "TEXT"
    }
}

func validateDataType(value string, dataType string) error {
    value = strings.TrimSpace(value)
    switch dataType {
    case "INTEGER":
        if _, err := strconv.Atoi(value); err != nil {
            return err
        }
    case "REAL":
        if _, err := strconv.ParseFloat(value, 64); err != nil {
            return err
        }
    case "TIMESTAMP":
        if _, err := time.Parse(time.RFC3339, value); err != nil {
            return err
        }
    case "DATETIME":
        if _, err := time.Parse("2006-01-02 15:04:05", value); err != nil {
            return err
        }
    case "DATE":
        if _, err := time.Parse("2006-01-02", value); err != nil {
            return err
        }
    }
    return nil
}

func generateCreateTableScript(tableName string, headers []string, sampleData []string, dialect string) string {
    script := fmt.Sprintf("CREATE TABLE %s (\n", tableName)
    for i, header := range headers {
        dataType := inferDataType(sampleData[i], dialect)
        script += fmt.Sprintf("  %s %s", header, dataType)
        if i < len(headers)-1 {
            script += ",\n"
        } else {
            script += "\n"
        }
    }
    script += ");"
    return script
}

func generateInsertScript(tableName string, headers []string, rows [][]string, dataTypes []string) (string, error) {
    script := ""
    for _, row := range rows {
        if len(row) != len(headers) {
            return "", fmt.Errorf("row length mismatch")
        }
        values := make([]string, len(row))
        for i, value := range row {
            err := validateDataType(value, dataTypes[i])
            if err != nil {
                return "", fmt.Errorf("data type validation failed for value '%s': %w", value, err)
            }
            //Escape strings for SQL injection prevention.  This is a simplification and should be improved for production use.
            if dataTypes[i] == "VARCHAR(255)" || dataTypes[i] == "TEXT" {
                values[i] = "'" + strings.ReplaceAll(value, "'", "''") + "'"
            } else {
                values[i] = value
            }

        }
        script += fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", tableName, strings.Join(headers, ","), strings.Join(values, ","))
    }
    return script, nil
}

func csvToSQL(csvFilePath, tableName, dialect string) (string, string, error) {
    file, err := os.Open(csvFilePath)
    if err != nil {
        return "", "", err
    }
    defer file.Close()

    reader := csv.NewReader(file)
    headers, err := reader.Read()
    if err != nil {
        return "", "", err
    }

    // Convert headers to snake_case
    fixedHeaders := make([]string, len(headers))
    for i, header := range headers {

        fixedHeader := strings.ReplaceAll(strings.TrimSpace(header), " ", "_")
        fixedHeaders[i] = fixedHeader
    }

    sampleData, err := reader.Read()
    if err != nil {
        return "", "", err
    }

    rows, err := reader.ReadAll()
    if err != nil {
        return "", "", err
    }

    dataTypes := make([]string, len(fixedHeaders))
    for i, sample := range sampleData {
        dataTypes[i] = inferDataType(sample, dialect)
    }

    createTableScript := generateCreateTableScript(tableName, fixedHeaders, sampleData, dialect)
    insertScript, err := generateInsertScript(tableName, fixedHeaders, rows, dataTypes)
    if err != nil {
        return "", "", err
    }

    return createTableScript, insertScript, nil
}

The updated code includes the completed generateCreateTableScript and generateInsertScript functions. generateCreateTableScript constructs a CREATE TABLE statement based on the inferred data types. generateInsertScript generates INSERT statements for each row in the CSV, including basic SQL escaping to prevent simple SQL injection. Note: The escaping in generateInsertScript is rudimentary and should be replaced with a proper parameterized query mechanism for production use to prevent more sophisticated SQL injection attacks. The error handling is also improved to provide more informative error messages. Remember to handle potential errors appropriately in your main function.

Model: gemini-1.5-flash-002
Temperature: 0.2
Max output tokens: 8192
TopK: 40.0
TopP: 0.95