#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream.h>
#include <fstream.h>

#include "e:\opt\common\timings.h"

// Returns the 16-bit key segment used to bucket a record on one sort pass.
//
// The key is formed from the two bytes ending at index Pass: the byte at
// Pass-1 is the high byte and the byte at Pass is the low byte. Both are read
// as unsigned char, so the result is always in [0, 65535] and can safely
// index a 65536-entry table.
//
// Pass       - index of the low-order byte of this pass's key segment.
// InputLine  - record text; must hold at least LineLength readable bytes.
// LineLength - number of bytes of InputLine that belong to the record.
//
// Returns HighByte * 256 + LowByte, where any position that is negative or
// at/after LineLength contributes 0.
int CalculateKeySegment(int Pass, char* InputLine, int LineLength)
{
  // Both bytes of the segment lie past the end of the record
  // (or Pass is invalid).
  if (Pass < 0 || LineLength < Pass)
    return 0;

  // Go through unsigned char: plain char may be signed, and a byte >= 0x80
  // would otherwise produce a negative key.
  unsigned char HighChar = 0;
  unsigned char LowChar = 0;
  if (Pass > 0)
    HighChar = static_cast<unsigned char>(InputLine[Pass-1]);
  if (Pass < LineLength)
    LowChar = static_cast<unsigned char>(InputLine[Pass]);

  return HighChar * 256 + LowChar;
}

// Prints Message followed by Detail to stderr and terminates with status 1.
static void FatalError(const char* Message, const char* Detail)
{
  fprintf(stderr, "%s%s\n", Message, Detail);
  exit(1);
}

// Reads the next line of InputFile into InputLine and re-appends the '\n'
// that getline() strips, NUL-terminating the result.
// InputLine must have room for MaxRead + 1 bytes: getline() stores at most
// MaxRead - 1 characters plus a NUL, and the '\n' needs one byte more.
// Returns the stored length including the '\n'; 0 when the stream fails at
// end of input; -1 when it fails anywhere else (standard getline() sets
// failbit without eofbit when a line has more than MaxRead - 1 characters).
static int ReadSortLine(ifstream& InputFile, char* InputLine, int MaxRead)
{
  InputFile.getline(InputLine, MaxRead);
  if (!InputFile)
    return InputFile.eof() ? 0 : -1;

  int Length = static_cast<int>(strlen(InputLine));
  InputLine[Length] = '\n';
  InputLine[Length + 1] = '\0';
  return Length + 1;
}

// zensort: orders the lines of infile by their first passcount bytes
// (compared as unsigned), writing the result to outfile.
//
// Usage: zensort passcount infile outfile
//
// Pass runs from passcount-1 down to 0 in steps of 2. Each iteration buckets
// every line by CalculateKeySegment(Pass, ...), i.e. the bytes at Pass-1 and
// Pass, and reads its input twice:
//   1. count the bytes in each of the BUFCOUNT buckets, which gives each
//      bucket's starting offset in the output file;
//   2. copy each line into its bucket's region, staging lines in per-bucket
//      slices of BigBuffer so the output is written in larger chunks.
// Lines keep their input order within a bucket. Intermediate results
// alternate between outfile and "outfile.tmp", arranged so the final pass
// always writes outfile; infile itself is only read.
int main(int argc, char *argv[])
{
  const int BUFCOUNT = 65536;            // one bucket per 16-bit key segment
  const int TOTAL_BUFFER = 16*1048576;   // staging bytes shared by all buckets
  const int INPUTLINESIZE = 1024;        // getline() limit: up to 1023 chars per line
  char InputLine[INPUTLINESIZE + 1];     // +1 so the re-appended '\n' always fits
  char TimingLabel[100];
  int TotalKeys = 0;
  int TotalData = 0;
  int TotalWrites = 0;
  double BufferRatio = 0.0;
  int LineLength;
  int i;                                 // loop index shared by the loops below

  if (argc < 4)
    {
    printf("Usage: zensort passcount infile outfile\n");
    exit(1);
    }

  const int PassCount = atoi(argv[1]);
  const char* OriginalInputFileName = argv[2];
  const char* OriginalOutputFileName = argv[3];
  if (PassCount < 1)
    FatalError("passcount must be a positive integer: ", argv[1]);

  // Intermediate passes alternate between outfile and outfile.tmp.
  char* TempFileName = new char[strlen(OriginalOutputFileName) + sizeof(".tmp")];
  strcpy(TempFileName, OriginalOutputFileName);
  strcat(TempFileName, ".tmp");

  // A pass truncates its output file while it still has to read its input,
  // so infile must not be one of the files this program writes.
  if (strcmp(OriginalInputFileName, OriginalOutputFileName) == 0 ||
      strcmp(OriginalInputFileName, TempFileName) == 0)
    FatalError("infile must differ from outfile and from outfile.tmp: ",
      OriginalInputFileName);

  char* BigBuffer = new char [TOTAL_BUFFER];
  char** Buffer = new char* [BUFCOUNT];
  int* BufferSize = new int[BUFCOUNT];
  int* BufferCharCount = new int[BUFCOUNT];
  int* Displacement = new int[BUFCOUNT];
  int* TotalDisplacement = new int[BUFCOUNT];

  const int TotalPasses = (PassCount + 1) / 2;
  const char* PreviousOutputFileName = OriginalInputFileName;
  int PassIndex = 0;

  start_timing();

  for (int Pass = PassCount - 1; Pass >= 0; Pass -= 2, PassIndex++)
    {
    const char* InputFileName = PreviousOutputFileName;
    // Choose the target so the final pass (PassIndex == TotalPasses - 1)
    // writes outfile and consecutive passes never use the same file.
    const char* OutputFileName =
      ((TotalPasses - 1 - PassIndex) % 2 == 0) ? OriginalOutputFileName
                                               : TempFileName;
    PreviousOutputFileName = OutputFileName;

    // Phase 1: count how many bytes fall into each bucket.
    for (i = 0; i < BUFCOUNT; i++)
      Displacement[i] = 0;

    {
    ifstream InputFile;   // fresh stream per phase: no stale eof/fail state
    InputFile.open(InputFileName, ios::in|ios::binary);
    if (!InputFile)
      FatalError("Cannot open input file ", InputFileName);

    while ((LineLength = ReadSortLine(InputFile, InputLine, INPUTLINESIZE)) > 0)
      {
      if (PassIndex == 0)
        TotalKeys++;
      Displacement[CalculateKeySegment(Pass, InputLine, LineLength)] += LineLength;
      }
    if (LineLength < 0)
      FatalError("Line too long or unreadable in ", InputFileName);
    InputFile.close();
    }

    // Bucket i starts where buckets 0..i-1 end.
    TotalDisplacement[0] = 0;
    for (i = 1; i < BUFCOUNT; i++)
      TotalDisplacement[i] = TotalDisplacement[i-1] + Displacement[i-1];

    if (PassIndex == 0)
      {
      TotalData = TotalDisplacement[BUFCOUNT-1] + Displacement[BUFCOUNT-1];
      // Empty input needs no staging space; avoid dividing by zero.
      BufferRatio = (TotalData > 0)
        ? static_cast<double>(TOTAL_BUFFER) / TotalData : 0.0;
      }

    // Give each bucket a slice of BigBuffer proportional to its data, clamped
    // so floating-point rounding can never run past the end of BigBuffer.
    int TotalBufferSize = 0;
    for (i = 0; i < BUFCOUNT; i++)
      {
      int Size = static_cast<int>(BufferRatio * Displacement[i]);
      if (Size > TOTAL_BUFFER - TotalBufferSize)
        Size = TOTAL_BUFFER - TotalBufferSize;
      BufferSize[i] = Size;
      Buffer[i] = BigBuffer + TotalBufferSize;
      BufferCharCount[i] = 0;
      TotalBufferSize += Size;
      }

    if (PassIndex == 0)
      {
      printf("Total buffer space: %d\n",TOTAL_BUFFER);
      printf("Total keys: %d\n", TotalKeys);
      printf("Total data: %d\n", TotalData);
      }

    sprintf(TimingLabel,"Finished counting on pass %d",PassCount-Pass);
    timing(TimingLabel);

    // Phase 2: move every line into its bucket's region of the output file.
    {
    ifstream InputFile;
    ofstream OutputFile;
    InputFile.open(InputFileName, ios::in|ios::binary);
    if (!InputFile)
      FatalError("Cannot open input file ", InputFileName);
    OutputFile.open(OutputFileName, ios::out|ios::binary);
    if (!OutputFile)
      FatalError("Cannot open output file ", OutputFileName);

    while ((LineLength = ReadSortLine(InputFile, InputLine, INPUTLINESIZE)) > 0)
      {
      int Key = CalculateKeySegment(Pass, InputLine, LineLength);
      char* Bucket = Buffer[Key];
      int Capacity = BufferSize[Key];
      int Used = BufferCharCount[Key];

      if (LineLength > Capacity)
        {
        // The line cannot fit even in an empty slice: flush what is staged,
        // then write the line straight to the file.
        OutputFile.seekp(TotalDisplacement[Key]);
        if (Used > 0)
          OutputFile.write(Bucket, Used);
        OutputFile.write(InputLine, LineLength);
        TotalDisplacement[Key] += Used + LineLength;
        BufferCharCount[Key] = 0;
        TotalWrites++;
        }
      else if (Used + LineLength >= Capacity)
        {
        // Top the slice up to capacity, flush it, and stage the remainder.
        int Head = Capacity - Used;
        memcpy(Bucket + Used, InputLine, Head);
        OutputFile.seekp(TotalDisplacement[Key]);
        OutputFile.write(Bucket, Capacity);
        TotalDisplacement[Key] += Capacity;
        TotalWrites++;
        memcpy(Bucket, InputLine + Head, LineLength - Head);
        BufferCharCount[Key] = LineLength - Head;
        }
      else
        {
        memcpy(Bucket + Used, InputLine, LineLength);
        BufferCharCount[Key] = Used + LineLength;
        }
      }
    if (LineLength < 0)
      FatalError("Line too long or unreadable in ", InputFileName);

    // Flush whatever is still staged in each bucket.
    for (i = 0; i < BUFCOUNT; i++)
      {
      if (BufferCharCount[i] > 0)
        {
        OutputFile.seekp(TotalDisplacement[i]);
        OutputFile.write(Buffer[i], BufferCharCount[i]);
        TotalWrites++;
        }
      }

    InputFile.close();
    OutputFile.close();
    if (!OutputFile)
      FatalError("Error writing output file ", OutputFileName);
    }

    sprintf(TimingLabel,"Finished distributing on pass %d",PassCount-Pass);
    timing(TimingLabel);
    }

  // With more than one pass the intermediate file was written; discard it.
  if (TotalPasses > 1)
    remove(TempFileName);

  printf("Total writes: %d\n", TotalWrites);

  end_timing();

  delete [] TotalDisplacement;
  delete [] Displacement;
  delete [] BufferCharCount;
  delete [] BufferSize;
  delete [] Buffer;
  delete [] BigBuffer;
  delete [] TempFileName;

  return 0;
}

