/*
** File: lock.c  
** Project: ADMXRC2 module driver
** Purpose: OS-independent IOCTL handlers relating to locking down user-space buffers.
**
** (C) Copyright Alpha Data 2013
*/

#include <df.h>
#include "device.h"
#include "lock.h"

#if defined(ADMXRC2_LOCKED_BUFFERS_GLOBAL)
/*
** Single locked-buffer handle table shared by every client when
** ADMXRC2_LOCKED_BUFFERS_GLOBAL is defined (otherwise each client context has
** its own table in pClCtx->lockedBuffers). Having static storage duration, its
** map[] starts out all NULL; initLockedBuffers() sets up the lock and free stack.
*/
static UserBufferMappings g_lockedBuffers;
#endif

DF_DECLARE_INLINE_FUNC(UserBufferMappings*, getUserBufferMappings)(
  Admxrc2ClientContext* pClCtx)
{
  /*
  ** Selects the handle table that a client's lock/unlock requests operate on:
  ** the driver-wide table in global mode, otherwise the client's own table.
  */
  UserBufferMappings* pMappings;

#if defined(ADMXRC2_LOCKED_BUFFERS_GLOBAL)
  (void)pClCtx; /* Not needed: all clients share one table. */
  pMappings = &g_lockedBuffers;
#else
  pMappings = &pClCtx->lockedBuffers;
#endif
  return pMappings;
}

/*
** Pops a free handle from the free stack.
** 'stack.top' counts handles currently allocated; free[top..MAX-1] are free.
** On success stores the handle in *pNewHandle and returns TRUE.
** Returns FALSE (leaving *pNewHandle untouched) if no free handles.
*/
static boolean_t
getHandle(
  Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  uint32_t* pNewHandle)
{
  UserBufferMappings* pMappings = getUserBufferMappings(pClCtx);
  DfSpinLockFlags flags;
  boolean_t bGotHandle = FALSE;
  uint32_t newHandle = 0;
  unsigned int slot;

  flags = dfSpinLockGet(&pMappings->lock);
  if (pMappings->stack.top != MAX_NUM_LOCKED_USER_SPACE_BUFFER) {
    slot = pMappings->stack.top;
    newHandle = pMappings->stack.free[slot];
    pMappings->stack.top = slot + 1;
    bGotHandle = TRUE;
  }
  dfSpinLockPut(&pMappings->lock, flags);

  if (bGotHandle) {
    *pNewHandle = newHandle;
  }
  return bGotHandle;
}

/*
** Returns a handle obtained from getHandle() to the free stack.
** Must only be called for a handle that is currently allocated (stack.top > 0).
*/
static void
putHandle(
  Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  uint32_t handle)
{
  UserBufferMappings* pMappings = getUserBufferMappings(pClCtx);
  DfSpinLockFlags flags;
  unsigned int newTop;

  flags = dfSpinLockGet(&pMappings->lock);
  newTop = pMappings->stack.top - 1U;
  pMappings->stack.top = newTop;
  pMappings->stack.free[newTop] = handle;
  dfSpinLockPut(&pMappings->lock, flags);
}

/*
** Makes a handle valid and binds it to a BufferDescription*.
** The description's reference count is set to 1; that reference belongs to the
** map slot and is dropped again by unbindHandle().
*/
static void
bindHandle(
  Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  uint32_t handle,
  BufferDescription* pDescription)
{
  UserBufferMappings* pMappings;
  DfSpinLockFlags flags;

  pMappings = getUserBufferMappings(pClCtx);
  flags = dfSpinLockGet(&pMappings->lock);
  pDescription->handle = handle;
  pDescription->refCount = 1;
  pMappings->map[handle] = pDescription;
  dfSpinLockPut(&pMappings->lock, flags);
}

/*
** Makes a handle invalid and unbinds it from the specified BufferDescription.
** It clears the handle's map slot, returns the handle to the free stack and
** drops the reference that the map slot held (set up by bindHandle).
**
** The unbind only happens if, under the lock, the slot still points at
** pDescription. Otherwise two concurrent unbinds of the same description would
** each push the handle onto the free stack, so it could later be handed out
** twice. They would also each drop the map's reference, freeing the
** description while a caller still holds one. An example is two ioctlUnlock
** calls racing on one handle, each holding a reference from
** referenceBufferHandle. The check also keeps a stale unbind from clobbering
** the mapping if the handle has since been recycled for a new description.
**
** Returns TRUE if the reference count went to 0 and the BufferDescription was
** freed. Returns FALSE otherwise, including when the handle was already unbound.
*/
static boolean_t
unbindHandle(
  Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  BufferDescription* pDescription)
{
  UserBufferMappings* pLockedBuffers = getUserBufferMappings(pClCtx);
  DfSpinLockFlags f;
  uint32_t handle, refCount;
  boolean_t bMustFree = FALSE;
  unsigned int top;

  /* 'handle' is only written when the description is set up/bound, so it is read before locking. */
  handle = pDescription->handle;
  f = dfSpinLockGet(&pLockedBuffers->lock);
  if (pLockedBuffers->map[handle] != pDescription) {
    /* Already unbound: the handle is already free and the map's reference already dropped. */
    dfSpinLockPut(&pLockedBuffers->lock, f);
    return FALSE;
  }
  pLockedBuffers->map[handle] = NULL;
  /* Push the handle back onto the free stack (same scheme as putHandle). */
  top = pLockedBuffers->stack.top - 1;
  pLockedBuffers->stack.free[top] = handle;
  pLockedBuffers->stack.top = top;
  /* Drop the reference held by the map slot. */
  refCount = pDescription->refCount - 1;
  pDescription->refCount = refCount;
  if (0 == refCount) {
    bMustFree = TRUE;
  }
  dfSpinLockPut(&pLockedBuffers->lock, f);

  if (bMustFree) {
    dfDebugPrint(3, ("unbindHandle: freeing, handle=%lu\n", (unsigned long)handle + 1));
    dfBufferDescriptionUnlock(pDescription->pDfBufferDesc);
    dfBufferDescriptionDelete(pDescription->pDfBufferDesc);
    dfFree(pDescription);
    return TRUE;
  } else {
    return FALSE;
  }
}

/*
** Looks up a (0-based) handle and, if it is currently bound, takes an extra
** reference on its BufferDescription. The caller must release that reference
** with dereferenceBufferDescription().
** Returns NULL if the handle is out of range or not bound.
*/
BufferDescription*
referenceBufferHandle(
  Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  uint32_t handle)
{
  UserBufferMappings* pMappings;
  BufferDescription* pFound = NULL;
  DfSpinLockFlags flags;

  if (handle < MAX_NUM_LOCKED_USER_SPACE_BUFFER) {
    pMappings = getUserBufferMappings(pClCtx);
    flags = dfSpinLockGet(&pMappings->lock);
    pFound = pMappings->map[handle];
    if (pFound != NULL) {
      pFound->refCount++;
    }
    dfSpinLockPut(&pMappings->lock, flags);
  }

  return pFound;
}

/*
** Releases one reference on a BufferDescription (e.g. one taken by
** referenceBufferHandle). When the count reaches 0, the underlying buffer is
** unlocked and deleted and the description itself is freed.
** Returns TRUE if the reference count went to 0 and the BufferDescription was freed.
*/
boolean_t
dereferenceBufferDescription(
  Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  BufferDescription* pDescription)
{
  UserBufferMappings* pMappings = getUserBufferMappings(pClCtx);
  DfSpinLockFlags flags;
  uint32_t handle = pDescription->handle;
  uint32_t remaining;

  flags = dfSpinLockGet(&pMappings->lock);
  remaining = pDescription->refCount - 1;
  pDescription->refCount = remaining;
  dfSpinLockPut(&pMappings->lock, flags);

  if (remaining != 0) {
    dfDebugPrint(6, ("dereferenceBufferDescription: not freeing, refCount=%lu handle=%lu\n", (unsigned long)remaining, (unsigned long)handle + 1));
    return FALSE;
  }

  /* Last reference gone: tear down the locked buffer. */
  dfDebugPrint(3, ("dereferenceBufferDescription: freeing, refCount=%lu handle=%lu\n", (unsigned long)remaining, (unsigned long)handle + 1));
  dfBufferDescriptionUnlock(pDescription->pDfBufferDesc);
  dfBufferDescriptionDelete(pDescription->pDfBufferDesc);
  dfFree(pDescription);
  return TRUE;
}

/*
** Unbinds every handle whose BufferDescription was created by pClCtx.
** Descriptions still referenced elsewhere are freed when their last reference is dropped.
**
** Each map slot, and the owner of the description it points to, is read under
** the table's spinlock, as every other map[] access in this file is. In
** ADMXRC2_LOCKED_BUFFERS_GLOBAL mode, a slot may belong to another client that
** is concurrently unlocking (and freeing) it. Reading pCreator without the lock
** could then touch freed memory.
**
** Descriptions owned by pClCtx stay valid after the lock is released: only
** unbindHandle drops the map's reference, and only the creator may unbind
** (see ioctlUnlock). NOTE(review): this assumes no unlock ioctl from pClCtx
** runs concurrently with cleanup - confirm with the caller.
*/
void
cleanupLockedBuffers(
	Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx)
{
  UserBufferMappings* pLockedBuffers = getUserBufferMappings(pClCtx);
  unsigned int i;
  BufferDescription* pDescription;
  BufferDescription* pOwned;
  unsigned long refCount = 0;
  DfSpinLockFlags f;

  for (i = 0; i < MAX_NUM_LOCKED_USER_SPACE_BUFFER; i++) {
    pOwned = NULL;
    f = dfSpinLockGet(&pLockedBuffers->lock);
    pDescription = pLockedBuffers->map[i];
    if (NULL != pDescription && pDescription->pCreator == pClCtx) {
      pOwned = pDescription;
      refCount = (unsigned long)pDescription->refCount; /* snapshot for the debug message */
    }
    dfSpinLockPut(&pLockedBuffers->lock, f);

    if (NULL != pOwned) {
      dfDebugPrint(1, ("cleanupLockedBuffers: i=%lu unbinding handle %lu, refCount=%lu\n",
        (unsigned long)i, (unsigned long)pOwned->handle + 1, refCount));
      unbindHandle(pDevCtx, pClCtx, pOwned);
    }
  }
}

/*
** LOCK ioctl: locks down the user-space buffer described by pIoctl->in and
** returns a 1-based handle for it in pIoctl->out.hBuffer.
** On any failure, everything acquired so far (description, handle, buffer
** description) is released in reverse order of acquisition.
*/
DfIoStatus
ioctlLock(
	Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  void* pBuffer,
  unsigned int inSize,
  unsigned int outSize)
{
  IOCTLS_ADMXRC2_LOCK* pIoctl = (IOCTLS_ADMXRC2_LOCK*)pBuffer;
  BufferDescription* pDesc;
  DfBufferDescription* pDfDesc = NULL;
  DfDescriptionResult createResult;
  DfIoStatus failStatus;
  uint32_t newHandle;

  /* Validate the request before acquiring anything. */
  if (NULL == pIoctl) {
    return DfIoStatusInvalid;
  }
  if (inSize != sizeof(pIoctl->in) || outSize != sizeof(pIoctl->out)) {
    return DfIoStatusInvalid;
  }
  if (!admxrc2CheckAccessRights(pClCtx)) {
    return DfIoStatusError(ADMXRC2_ACCESS_DENIED);
  }

  pDesc = (BufferDescription*)dfMalloc(sizeof(*pDesc));
  if (NULL == pDesc) {
    return DfIoStatusError(ADMXRC2_NO_MEMORY);
  }
  pDesc->handle = 0;
  pDesc->refCount = 0;
  pDesc->pCreator = pClCtx;
  pDesc->pDfBufferDesc = NULL;

  if (!getHandle(pDevCtx, pClCtx, &newHandle)) {
    failStatus = DfIoStatusError(ADMXRC2_NO_DMADESC);
    goto freeDescription;
  }

  createResult = dfBufferDescriptionCreateUser(pIoctl->in.pBuffer, pIoctl->in.length, &pDfDesc);
  if (DfDescriptionSuccess != createResult) {
    /* Translate the framework's result into an ADMXRC2 status. */
    if (DfDescriptionNullPointer == createResult) {
      failStatus = DfIoStatusError(ADMXRC2_NULL_POINTER);
    } else if (DfDescriptionInvalidBuffer == createResult) {
      failStatus = DfIoStatusError(ADMXRC2_INVALID_PARAMETER); /* TODO - could return better status code here */
    } else if (DfDescriptionNoMemory == createResult) {
      failStatus = DfIoStatusError(ADMXRC2_NO_MEMORY);
    } else {
      failStatus = DfIoStatusError(ADMXRC2_UNKNOWN_ERROR);
    }
    goto releaseHandle;
  }
  pDesc->pDfBufferDesc = pDfDesc;

  if (!dfBufferDescriptionLock(pDfDesc)) {
    dfBufferDescriptionDelete(pDfDesc);
    failStatus = DfIoStatusError(ADMXRC2_INVALID_PARAMETER); /* TODO - could return better status code here */
    goto releaseHandle;
  }

  /* Publish the description; from here on it is reachable via its handle. */
  bindHandle(pDevCtx, pClCtx, newHandle, pDesc);

  /* User-space handles are 1-based; internal handles are 0-based. */
  pIoctl->out.hBuffer = newHandle + 1;

  return DfIoStatusSuccess;

releaseHandle:
  putHandle(pDevCtx, pClCtx, newHandle);
freeDescription:
  dfFree(pDesc);
  return failStatus;
}

#if DF_NEED_THUNK
/*
** 32-bit compatibility wrapper for the LOCK ioctl.
** It converts the 32-bit request layout (IOCTLS32_ADMXRC2_LOCK) into the native
** one, calls ioctlLock, and on success copies the resulting handle back.
**
** inSize/outSize are validated here against the 32-bit layout. ioctlLock
** validates its sizes against the native IOCTLS_ADMXRC2_LOCK layout. It must
** therefore be given the native sizes for both 'in' and 'out', not the caller's
** 32-bit ones; otherwise every call would fail whenever the two layouts differ
** in size.
*/
DfIoStatus
ioctlLockThunk(
	Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  void* pBuffer,
  unsigned int inSize,
  unsigned int outSize)
{
  IOCTLS32_ADMXRC2_LOCK* pIoctl32 = (IOCTLS32_ADMXRC2_LOCK*)pBuffer;
  IOCTLS_ADMXRC2_LOCK ioctl;
  DfIoStatus status;

  if (NULL == pIoctl32 || sizeof(pIoctl32->in) != inSize || sizeof(pIoctl32->out) != outSize) {
    return DfIoStatusInvalid;
  }

  ioctl.in.pBuffer = dfThunkPtr(pIoctl32->in.pBuffer);
  ioctl.in.length = (size_t)pIoctl32->in.length;
  status = ioctlLock(pDevCtx, pClCtx, &ioctl, sizeof(ioctl.in), sizeof(ioctl.out));
  if (status == DfIoStatusSuccess) {
    pIoctl32->out.hBuffer = ioctl.out.hBuffer;
  }
  return status;
}
#endif

/*
** UNLOCK ioctl: releases a buffer previously locked via the LOCK ioctl.
** Only the client that created the buffer may unlock it; other clients get
** ADMXRC2_NOT_OWNER. An unknown handle yields ADMXRC2_INVALID_DMADESC.
*/
DfIoStatus
ioctlUnlock(
	Admxrc2DeviceContext* pDevCtx,
  Admxrc2ClientContext* pClCtx,
  void* pBuffer,
  unsigned int inSize)
{
  IOCTLS_ADMXRC2_UNLOCK* pIoctl = (IOCTLS_ADMXRC2_UNLOCK*)pBuffer;
  BufferDescription* pDesc;
  DfIoStatus result;
  uint32_t index;

  if (NULL == pIoctl || inSize != sizeof(pIoctl->in)) {
    return DfIoStatusInvalid;
  }

  /* hBuffer is 1-based (ioctlLock returns handle + 1); convert to the internal index. */
  index = pIoctl->in.hBuffer - 1;

  /* Hold a reference so the description cannot vanish while we inspect it. */
  pDesc = referenceBufferHandle(pDevCtx, pClCtx, index);
  if (NULL == pDesc) {
    return DfIoStatusError(ADMXRC2_INVALID_DMADESC);
  }

  if (pDesc->pCreator == pClCtx) {
    unbindHandle(pDevCtx, pClCtx, pDesc);
    result = DfIoStatusSuccess;
  } else {
    result = DfIoStatusError(ADMXRC2_NOT_OWNER);
  }

  /* Drop the reference taken above; this may free the description. */
  dereferenceBufferDescription(pDevCtx, pClCtx, pDesc);

  return result;
}

#if defined(ADMXRC2_LOCKED_BUFFERS_GLOBAL)
/*
** Initializes the driver-wide locked-buffer table used in
** ADMXRC2_LOCKED_BUFFERS_GLOBAL mode: the lock, and a free stack holding every
** handle 0..MAX_NUM_LOCKED_USER_SPACE_BUFFER-1 with none allocated.
*/
void
initLockedBuffers(
	DfDriverObject* pDrvObj)
{
  unsigned int slot = 0;

  dfSpinLockInit(&g_lockedBuffers.lock);

  /* No handles allocated yet. */
  g_lockedBuffers.stack.top = 0;

  /* Handle actually returned to user-space app is 'slot + 1' */
  while (slot < MAX_NUM_LOCKED_USER_SPACE_BUFFER) {
    g_lockedBuffers.stack.free[slot] = slot;
    slot++;
  }
}
#endif
