#include <stdlib.h>
#include <string.h>
#include <hfs/hfsplus.h>

int writeExtents(RawFile* rawFile);

int isBlockUsed(Volume* volume, uint32_t block)
{
  unsigned char byte;
  
  READ(volume->allocationFile, block / 8, 1, &byte);
  return (byte & (1 << (7 - (block % 8)))) != 0;
}

int setBlockUsed(Volume* volume, uint32_t block, int used) {
  unsigned char byte;
  
  READ(volume->allocationFile, block / 8, 1, &byte);
  if(used) {
    byte |= (1 << (7 - (block % 8)));
  } else {
    byte &= ~(1 << (7 - (block % 8)));
  }
  ASSERT(WRITE(volume->allocationFile, block / 8, 1, &byte), "WRITE");
  
  return TRUE;
}

int allocate(RawFile* rawFile, off_t size) {
  unsigned char* zeros;
  Volume* volume;
  HFSPlusForkData* forkData;
  uint32_t blocksNeeded;
  uint32_t blocksToAllocate;
  Extent* extent;
  Extent* lastExtent;
  
  uint32_t curBlock;
  
  volume = rawFile->volume;
  forkData = rawFile->forkData;
  extent = rawFile->extents;
 
  blocksNeeded = ((uint64_t)size / (uint64_t)volume->volumeHeader->blockSize) + (((size % volume->volumeHeader->blockSize) == 0) ? 0 : 1);
  
  if(blocksNeeded > forkData->totalBlocks) {
    zeros = (unsigned char*) malloc(volume->volumeHeader->blockSize);
    memset(zeros, 0, volume->volumeHeader->blockSize);
  
    blocksToAllocate = blocksNeeded - forkData->totalBlocks;
    
    if(blocksToAllocate > volume->volumeHeader->freeBlocks) {
      return FALSE;
    }
    
    lastExtent = NULL;
    while(extent != NULL) {
      lastExtent = extent;
      extent = extent->next;
    }
    
    if(lastExtent == NULL) {
      rawFile->extents = (Extent*) malloc(sizeof(Extent));
      lastExtent = rawFile->extents;
      lastExtent->blockCount = 0;
      lastExtent->next = NULL;
      curBlock = volume->volumeHeader->nextAllocation;
    } else {
      curBlock = lastExtent->startBlock + lastExtent->blockCount;
    }
    
    while(blocksToAllocate > 0) {
      if(isBlockUsed(volume, curBlock)) {
        if(lastExtent->blockCount > 0) {
          lastExtent->next = (Extent*) malloc(sizeof(Extent));
          lastExtent = lastExtent->next;
          lastExtent->blockCount = 0;
          lastExtent->next = NULL;
        }
        curBlock = volume->volumeHeader->nextAllocation;
        volume->volumeHeader->nextAllocation++;
        if(volume->volumeHeader->nextAllocation >= volume->volumeHeader->totalBlocks) {
          volume->volumeHeader->nextAllocation = 0;
        }
      } else {
        if(lastExtent->blockCount == 0) {
          lastExtent->startBlock = curBlock;
        }

        /* zero out allocated block */
        ASSERT(WRITE(volume->image, curBlock * volume->volumeHeader->blockSize, volume->volumeHeader->blockSize, zeros), "WRITE");

        setBlockUsed(volume, curBlock, TRUE);
        volume->volumeHeader->freeBlocks--;
        blocksToAllocate--;
        curBlock++;
        lastExtent->blockCount++;
        
        if(curBlock >= volume->volumeHeader->totalBlocks) {
          curBlock = volume->volumeHeader->nextAllocation;
        }
      }
    }
    
    free(zeros);
  } else if(blocksNeeded < forkData->totalBlocks) {
    blocksToAllocate = blocksNeeded;
  
    lastExtent = NULL;
    
    while(blocksToAllocate > 0) {
      if(blocksToAllocate > extent->blockCount) {
        blocksToAllocate -= extent->blockCount;
        lastExtent = extent;
        extent = extent->next;
      } else {
        break;
      }
    }
    
    
    if(blocksToAllocate == 0 && lastExtent != NULL) {
      lastExtent->next = NULL;
    } else {
      rawFile->extents = NULL;
    }
    
    do {
      for(curBlock = (extent->startBlock + blocksToAllocate); curBlock < (extent->startBlock + extent->blockCount); curBlock++) {
        setBlockUsed(volume, curBlock, FALSE);
        volume->volumeHeader->freeBlocks++;
      }
      blocksToAllocate = 0;
      lastExtent = extent;
      extent = extent->next;
      free(lastExtent);
      
    } while(extent != NULL);
  }

  writeExtents(rawFile);
  
  forkData->logicalSize = size;
  forkData->totalBlocks = blocksNeeded;

  updateVolume(rawFile->volume);
  
  if(rawFile->catalogRecord != NULL) {
    updateCatalog(rawFile->volume, rawFile->catalogRecord);
  }
  
  return TRUE;
}

static int rawFileRead(io_func* io,off_t location, size_t size, void *buffer) {
  RawFile* rawFile;
  Volume* volume;
  Extent* extent;
  
  size_t blockSize;
  off_t fileLoc;
  off_t locationInBlock;
  size_t possible;
  
  rawFile = (RawFile*) io->data;
  volume = rawFile->volume;
  blockSize = volume->volumeHeader->blockSize;
  
  extent = rawFile->extents;
  fileLoc = 0;
  
  locationInBlock = location;
  while(TRUE) {
    fileLoc += extent->blockCount * blockSize;
    if(fileLoc <= location) {
      locationInBlock -= extent->blockCount * blockSize;
      extent = extent->next;
      if(extent == NULL)
        break;
    } else {
      break;
    }
  }
  
  while(size > 0) {
    if(extent == NULL)
      return FALSE;
      
    possible = extent->blockCount * blockSize - locationInBlock;
        
    if(size > possible) {
      ASSERT(READ(volume->image, extent->startBlock * blockSize + locationInBlock, possible, buffer), "READ");
      size -= possible;
      buffer = (void*)(((size_t)buffer) + possible);
      extent = extent->next;
    } else {
      ASSERT(READ(volume->image, extent->startBlock * blockSize + locationInBlock, size, buffer), "READ");
      break;
    }
    
    locationInBlock = 0;
  }
   
  return TRUE;
}

static int rawFileWrite(io_func* io,off_t location, size_t size, void *buffer) {
  RawFile* rawFile;
  Volume* volume;
  Extent* extent;
  
  size_t blockSize;
  off_t fileLoc;
  off_t locationInBlock;
  size_t possible;
  
  rawFile = (RawFile*) io->data;
  volume = rawFile->volume;
  blockSize = volume->volumeHeader->blockSize;
  
  if(rawFile->forkData->logicalSize < (location + size)) {
    ASSERT(allocate(rawFile, location + size), "allocate");
  }
  
  extent = rawFile->extents;
  fileLoc = 0;
  
  locationInBlock = location;
  while(TRUE) {
    fileLoc += extent->blockCount * blockSize;
    if(fileLoc <= location) {
      locationInBlock -= extent->blockCount * blockSize;
      extent = extent->next;
      if(extent == NULL)
        break;
    } else {
      break;
    }
  }
  
  while(size > 0) {
    if(extent == NULL)
      return FALSE;
      
    possible = extent->blockCount * blockSize - locationInBlock;
        
    if(size > possible) {
      ASSERT(WRITE(volume->image, extent->startBlock * blockSize + locationInBlock, possible, buffer), "WRITE");
      size -= possible;
      buffer = (void*)(((size_t)buffer) + possible);
      extent = extent->next;
    } else {
      ASSERT(WRITE(volume->image, extent->startBlock * blockSize + locationInBlock, size, buffer), "WRITE");
      break;
    }
    
    locationInBlock = 0;
  }
   
  return TRUE;
}

static void closeRawFile(io_func* io) {
  RawFile* rawFile;
  Extent* extent;
  Extent* toRemove;
  
  rawFile = (RawFile*) io->data;
  extent = rawFile->extents;
  
  while(extent != NULL) {
    toRemove = extent;
    extent = extent->next;
    free(toRemove);
  }
  
  free(rawFile);
  free(io);
}

int removeExtents(RawFile* rawFile) {
  uint32_t blocksLeft;
  HFSPlusForkData* forkData;
  uint32_t currentBlock;
  
  uint32_t startBlock;
  uint32_t blockCount;
    
  HFSPlusExtentDescriptor* descriptor;
  int currentExtent;
  HFSPlusExtentKey extentKey;
  int exact;
  
  extentKey.keyLength = sizeof(HFSPlusExtentKey) - sizeof(extentKey.keyLength);
  extentKey.forkType = 0;
  extentKey.fileID = rawFile->id;
  
  forkData = rawFile->forkData;
  blocksLeft = forkData->totalBlocks;
  currentExtent = 0;
  currentBlock = 0;
  descriptor = (HFSPlusExtentDescriptor*) forkData->extents;
  
  while(blocksLeft > 0) {       
    if(currentExtent == 8) {
      if(rawFile->volume->extentsTree == NULL) {
        panic("no extents overflow file loaded yet!");
        return FALSE;
      }
      
      if(descriptor != ((HFSPlusExtentDescriptor*) forkData->extents)) {
        free(descriptor);
      }
      
      extentKey.startBlock = currentBlock;
      descriptor = (HFSPlusExtentDescriptor*) search(rawFile->volume->extentsTree, (BTKey*)(&extentKey), &exact, NULL, NULL);
      if(descriptor == NULL || exact == FALSE) {
        panic("inconsistent extents information!");
        return FALSE;
      } else {
        removeFromBTree(rawFile->volume->extentsTree, (BTKey*)(&extentKey));
        currentExtent = 0;
        continue;
      }
    }

    startBlock = descriptor[currentExtent].startBlock;
    blockCount = descriptor[currentExtent].blockCount;
    
    currentBlock += blockCount;
    blocksLeft -= blockCount;
    currentExtent++;
  }
  
  if(descriptor != ((HFSPlusExtentDescriptor*) forkData->extents)) {
    free(descriptor);
  }
  
  return TRUE;
}

int writeExtents(RawFile* rawFile) {
  Extent* extent;
  int currentExtent;
  HFSPlusExtentKey extentKey;
  HFSPlusExtentDescriptor descriptor[8];
  HFSPlusForkData* forkData;
  
  removeExtents(rawFile);
  
  forkData = rawFile->forkData;
  currentExtent = 0;
  extent = rawFile->extents;
  
  memset(forkData->extents, 0, sizeof(HFSPlusExtentRecord));
  while(extent != NULL && currentExtent < 8) {
    ((HFSPlusExtentDescriptor*)forkData->extents)[currentExtent].startBlock = extent->startBlock;
    ((HFSPlusExtentDescriptor*)forkData->extents)[currentExtent].blockCount = extent->blockCount;
    extent = extent->next;
    currentExtent++;
  }
  
  if(extent != NULL) {
    extentKey.keyLength = sizeof(HFSPlusExtentKey) - sizeof(extentKey.keyLength);
    extentKey.forkType = 0;
    extentKey.fileID = rawFile->id;
    
    currentExtent = 0;
    
    while(extent != NULL) {
      if(currentExtent == 0) {
        memset(descriptor, 0, sizeof(HFSPlusExtentRecord));
      }
      
      if(currentExtent == 8) {
        extentKey.startBlock = descriptor[0].startBlock;
        addToBTree(rawFile->volume->extentsTree, (BTKey*)(&extentKey), sizeof(HFSPlusExtentRecord), (unsigned char *)(&(descriptor[0])));
        currentExtent = 0;
      }
      
      descriptor[currentExtent].startBlock = extent->startBlock;
      descriptor[currentExtent].blockCount = extent->blockCount;
      
      currentExtent++;
      extent = extent->next;
    }
       
    extentKey.startBlock = descriptor[0].startBlock;
    addToBTree(rawFile->volume->extentsTree, (BTKey*)(&extentKey), sizeof(HFSPlusExtentRecord), (unsigned char *)(&(descriptor[0])));
  }
  
  return TRUE;
}

int readExtents(RawFile* rawFile) {
  uint32_t blocksLeft;
  HFSPlusForkData* forkData;
  uint32_t currentBlock;
  
  Extent* extent;
  Extent* lastExtent;
  
  HFSPlusExtentDescriptor* descriptor;
  int currentExtent;
  HFSPlusExtentKey extentKey;
  int exact;
  
  extentKey.keyLength = sizeof(HFSPlusExtentKey) - sizeof(extentKey.keyLength);
  extentKey.forkType = 0;
  extentKey.fileID = rawFile->id;
  
  forkData = rawFile->forkData;
  blocksLeft = forkData->totalBlocks;
  currentExtent = 0;
  currentBlock = 0;
  descriptor = (HFSPlusExtentDescriptor*) forkData->extents;
  
  lastExtent = NULL;
  
  while(blocksLeft > 0) {
    extent = (Extent*) malloc(sizeof(Extent));
       
    if(currentExtent == 8) {
      if(rawFile->volume->extentsTree == NULL) {
        panic("no extents overflow file loaded yet!");
        return FALSE;
      }
      
      if(descriptor != ((HFSPlusExtentDescriptor*) forkData->extents)) {
        free(descriptor);
      }
      
      extentKey.startBlock = currentBlock;
      descriptor = (HFSPlusExtentDescriptor*) search(rawFile->volume->extentsTree, (BTKey*)(&extentKey), &exact, NULL, NULL);
      if(descriptor == NULL || exact == FALSE) {
        panic("inconsistent extents information!");
        return FALSE;
      } else {
        currentExtent = 0;
        continue;
      }
    }
    
    extent->startBlock = descriptor[currentExtent].startBlock;
    extent->blockCount = descriptor[currentExtent].blockCount;
    extent->next = NULL;
    
    currentBlock += extent->blockCount;
    blocksLeft -= extent->blockCount;
    currentExtent++;
    
    if(lastExtent == NULL) {
      rawFile->extents = extent;
    } else {
      lastExtent->next = extent;
    }
    
    lastExtent = extent;
  }
  
  if(descriptor != ((HFSPlusExtentDescriptor*) forkData->extents)) {
    free(descriptor);
  }
  
  return TRUE;
}

io_func* openRawFile(HFSCatalogNodeID id, HFSPlusForkData* forkData, HFSPlusCatalogRecord* catalogRecord, Volume* volume) {
  io_func* io;
  RawFile* rawFile;
  
  io = (io_func*) malloc(sizeof(io_func));
  rawFile = (RawFile*) malloc(sizeof(RawFile));
  
  rawFile->id = id;
  rawFile->volume = volume;
  rawFile->forkData = forkData;
  rawFile->catalogRecord = catalogRecord;
  rawFile->extents = NULL;
  
  io->data = rawFile;
  io->read = &rawFileRead;
  io->write = &rawFileWrite;
  io->close = &closeRawFile;
  
  if(!readExtents(rawFile)) {
    return NULL;
  }

  return io;
}
