Go 함수 완성

Go를 사용하여 CSV 파일을 기반으로 SQL 스크립트 생성

모델을 직접 쿼리하고 Cloud 콘솔에서 다른 매개변수 값을 사용하거나 Vertex AI API를 직접 호출하여 반환된 결과를 테스트할 수 있습니다.

시스템 안내

SQL 데이터베이스를 전문으로 하는 Go 개발자입니다.

자유 형식 프롬프트

CSV를 SQL로 변환하는 코드가 있습니다. 테이블을 만들고 스크립트를 삽입하는 함수가 완료되지 않았습니다. 이 두 가지 함수를 완료해 주시겠어요?

convert.go

package main

import (
  "encoding/csv"
  "fmt"
  "os"
  "strconv"
  "strings"
  "time"
)

func inferDataType(value string, dialect string) string {
  value = strings.TrimSpace(value) // Trim leading and trailing spaces
  if _, err := strconv.Atoi(value); err == nil {
    return getIntegerType(dialect)
  }
  if _, err := strconv.ParseFloat(value, 64); err == nil {
    return getFloatType(dialect)
  }
  if _, err := time.Parse(time.RFC3339, value); err == nil {
    return getTimestampType(dialect)
  }
  if _, err := time.Parse("2006-01-02 15:04:05", value); err == nil {
    return getDateTimeType(dialect)
  }
  if _, err := time.Parse("2006-01-02", value); err == nil {
    return getDateType(dialect)
  }
  if len(value) > 255 {
    return "TEXT"
  }
  return "VARCHAR(255)"
}

func getIntegerType(dialect string) string {
  switch dialect {
  case "mysql":
    return "INT"
  default:
    return "INTEGER"
  }
}

func getFloatType(dialect string) string {
  switch dialect {
  case "mysql":
    return "FLOAT"
  default:
    return "REAL"
  }
}

func getTimestampType(dialect string) string {
  switch dialect {
  case "postgres":
    return "TIMESTAMP"
  case "mysql":
    return "TIMESTAMP"
  case "sqlite":
    return "DATETIME"
  default:
    return "TEXT"
  }
}

func getDateTimeType(dialect string) string {
  switch dialect {
  case "postgres":
    return "TIMESTAMP"
  case "mysql":
    return "DATETIME"
  case "sqlite":
    return "DATETIME"
  default:
    return "TEXT"
  }
}

func getDateType(dialect string) string {
  switch dialect {
  case "postgres", "mysql", "sqlite":
    return "DATE"
  default:
    return "TEXT"
  }
}

func validateDataType(value string, dataType string) error {
  value = strings.TrimSpace(value)
  switch dataType {
  case "INTEGER":
    if _, err := strconv.Atoi(value); err != nil {
      return err
    }
  case "REAL":
    if _, err := strconv.ParseFloat(value, 64); err != nil {
      return err
    }
  case "TIMESTAMP":
    if _, err := time.Parse(time.RFC3339, value); err != nil {
      return err
    }
  case "DATETIME":
    if _, err := time.Parse("2006-01-02 15:04:05", value); err != nil {
      return err
    }
  case "DATE":
    if _, err := time.Parse("2006-01-02", value); err != nil {
      return err
    }
  }
  return nil
}

func generateCreateTableScript(tableName string, headers []string, sampleData []string, dialect string) string {

}

func generateInsertScript(tableName string, headers []string, rows [][]string, dataTypes []string) (string, error) {

}

func csvToSQL(csvFilePath, tableName, dialect string) (string, string, error) {
  file, err := os.Open(csvFilePath)
  if err != nil {
    return "", "", err
  }
  defer file.Close()

  reader := csv.NewReader(file)
  headers, err := reader.Read()
  if err != nil {
    return "", "", err
  }

  // Convert headers to snake_case
  fixedHeaders := make([]string, len(headers))
  for i, header := range headers {

    fixedHeader := strings.ReplaceAll(strings.TrimSpace(header), " ", "_")
    fixedHeaders[i] = fixedHeader
  }

  sampleData, err := reader.Read()
  if err != nil {
    return "", "", err
  }

  rows, err := reader.ReadAll()
  if err != nil {
    return "", "", err
  }

  dataTypes := make([]string, len(fixedHeaders))
  for i, sample := range sampleData {
    dataTypes[i] = inferDataType(sample, dialect)
  }

  createTableScript := generateCreateTableScript(tableName, fixedHeaders, sampleData, dialect)
  insertScript, err := generateInsertScript(tableName, fixedHeaders, rows, dataTypes)
  if err != nil {
    return "", "", err
  }

  return createTableScript, insertScript, nil
}

main.go

package main

import (
  "flag"
  "fmt"
  "os"
  "path/filepath"
)

func main() {
  csvFilePath := flag.String("f", "", "Path to the CSV file")
  tableName := flag.String("t", "default_table", "Name of the table")
  dialect := flag.String("d", "postgres", "SQL dialect (postgres, mysql, sqlite)")

  flag.Parse()

  // Check if file path is provided
  if *csvFilePath == "" {
    fmt.Println("Error: CSV file path is required")
    flag.Usage()
    os.Exit(1)
  }

  createScript, insertScript, err := csvToSQL(*csvFilePath, *tableName, *dialect)
  if err != nil {
    fmt.Println("Error:", err)
    return
  }

  createFilePath := filepath.Join(".", fmt.Sprintf("%s_create.sql", *tableName))
  err = os.WriteFile(createFilePath, []byte(createScript), 0644)
  if err != nil {
    fmt.Println("Error writing CREATE TABLE script:", err)
    return
  }

  insertFilePath := filepath.Join(".", fmt.Sprintf("%s_insert.sql", *tableName))
  err = os.WriteFile(insertFilePath, []byte(insertScript), 0644)
  if err != nil {
    fmt.Println("Error writing INSERT scripts:", err)
    return
  }

  fmt.Printf("CREATE TABLE script written to: %s\n", createFilePath)
  fmt.Printf("INSERT scripts written to: %s\n", insertFilePath)
}

응답

package main

import (
    "encoding/csv"
    "fmt"
    "os"
    "strconv"
    "strings"
    "time"
)

func inferDataType(value string, dialect string) string {
    value = strings.TrimSpace(value) // Trim leading and trailing spaces
    if _, err := strconv.Atoi(value); err == nil {
        return getIntegerType(dialect)
    }
    if _, err := strconv.ParseFloat(value, 64); err == nil {
        return getFloatType(dialect)
    }
    if _, err := time.Parse(time.RFC3339, value); err == nil {
        return getTimestampType(dialect)
    }
    if _, err := time.Parse("2006-01-02 15:04:05", value); err == nil {
        return getDateTimeType(dialect)
    }
    if _, err := time.Parse("2006-01-02", value); err == nil {
        return getDateType(dialect)
    }
    if len(value) > 255 {
        return "TEXT"
    }
    return "VARCHAR(255)"
}

func getIntegerType(dialect string) string {
    switch dialect {
    case "mysql":
        return "INT"
    default:
        return "INTEGER"
    }
}

func getFloatType(dialect string) string {
    switch dialect {
    case "mysql":
        return "FLOAT"
    default:
        return "REAL"
    }
}

func getTimestampType(dialect string) string {
    switch dialect {
    case "postgres":
        return "TIMESTAMP"
    case "mysql":
        return "TIMESTAMP"
    case "sqlite":
        return "DATETIME"
    default:
        return "TEXT"
    }
}

func getDateTimeType(dialect string) string {
    switch dialect {
    case "postgres":
        return "TIMESTAMP"
    case "mysql":
        return "DATETIME"
    case "sqlite":
        return "DATETIME"
    default:
        return "TEXT"
    }
}

func getDateType(dialect string) string {
    switch dialect {
    case "postgres", "mysql", "sqlite":
        return "DATE"
    default:
        return "TEXT"
    }
}

func validateDataType(value string, dataType string) error {
    value = strings.TrimSpace(value)
    switch dataType {
    case "INTEGER":
        if _, err := strconv.Atoi(value); err != nil {
            return err
        }
    case "REAL":
        if _, err := strconv.ParseFloat(value, 64); err != nil {
            return err
        }
    case "TIMESTAMP":
        if _, err := time.Parse(time.RFC3339, value); err != nil {
            return err
        }
    case "DATETIME":
        if _, err := time.Parse("2006-01-02 15:04:05", value); err != nil {
            return err
        }
    case "DATE":
        if _, err := time.Parse("2006-01-02", value); err != nil {
            return err
        }
    }
    return nil
}

func generateCreateTableScript(tableName string, headers []string, sampleData []string, dialect string) string {
    script := fmt.Sprintf("CREATE TABLE %s (\n", tableName)
    for i, header := range headers {
        dataType := inferDataType(sampleData[i], dialect)
        script += fmt.Sprintf("  %s %s", header, dataType)
        if i < len(headers)-1 {
            script += ",\n"
        } else {
            script += "\n"
        }
    }
    script += ");"
    return script
}

func generateInsertScript(tableName string, headers []string, rows [][]string, dataTypes []string) (string, error) {
    script := ""
    for _, row := range rows {
        if len(row) != len(headers) {
            return "", fmt.Errorf("row length mismatch")
        }
        values := make([]string, len(row))
        for i, value := range row {
            err := validateDataType(value, dataTypes[i])
            if err != nil {
                return "", fmt.Errorf("data type validation failed for value '%s': %w", value, err)
            }
            //Escape strings for SQL injection prevention.  This is a simplification and should be improved for production use.
            if dataTypes[i] == "VARCHAR(255)" || dataTypes[i] == "TEXT" {
                values[i] = "'" + strings.ReplaceAll(value, "'", "''") + "'"
            } else {
                values[i] = value
            }

        }
        script += fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", tableName, strings.Join(headers, ","), strings.Join(values, ","))
    }
    return script, nil
}

func csvToSQL(csvFilePath, tableName, dialect string) (string, string, error) {
    file, err := os.Open(csvFilePath)
    if err != nil {
        return "", "", err
    }
    defer file.Close()

    reader := csv.NewReader(file)
    headers, err := reader.Read()
    if err != nil {
        return "", "", err
    }

    // Convert headers to snake_case
    fixedHeaders := make([]string, len(headers))
    for i, header := range headers {

        fixedHeader := strings.ReplaceAll(strings.TrimSpace(header), " ", "_")
        fixedHeaders[i] = fixedHeader
    }

    sampleData, err := reader.Read()
    if err != nil {
        return "", "", err
    }

    rows, err := reader.ReadAll()
    if err != nil {
        return "", "", err
    }

    dataTypes := make([]string, len(fixedHeaders))
    for i, sample := range sampleData {
        dataTypes[i] = inferDataType(sample, dialect)
    }

    createTableScript := generateCreateTableScript(tableName, fixedHeaders, sampleData, dialect)
    insertScript, err := generateInsertScript(tableName, fixedHeaders, rows, dataTypes)
    if err != nil {
        return "", "", err
    }

    return createTableScript, insertScript, nil
}

업데이트된 코드에는 완성된 generateCreateTableScriptgenerateInsertScript 함수가 포함되어 있습니다. generateCreateTableScript는 추론된 데이터 유형을 기반으로 CREATE TABLE 문을 생성합니다. generateInsertScript는 간단한 SQL 삽입을 방지하기 위한 기본 SQL 이스케이프를 포함하여 CSV의 각 행에 대한 INSERT 문을 생성합니다. 참고: generateInsertScript의 이스케이핑은 초보적인 수준이며 더 정교한 SQL 삽입 공격을 방지하려면 프로덕션용으로 적절한 매개변수화된 쿼리 메커니즘으로 대체해야 합니다. 또한 더 많은 정보를 제공하는 오류 메시지를 제공하도록 오류 처리가 개선되었습니다. 기본 함수에서 잠재적 오류를 적절하게 처리해야 합니다.

모델: gemini-1.5-flash-002
강도: 0.2
최대 출력 토큰: 8192
TopK: 40.0
최상위 P: 0.95